import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
    confusion_matrix, classification_report, roc_curve,
)
from sklearn.model_selection import train_test_split
from pathlib import Path
from datetime import datetime
from loguru import logger


def create_temporal_edge_features(time_since_src, time_since_tgt, user_i, user_j):
    """Build 4-dim temporal features for each edge from its endpoint time gap."""
    delta_t = torch.abs(time_since_src - time_since_tgt).float()
    # Periodic encodings of the gap at hourly, daily, and weekly scales.
    hour_scale = torch.sin(delta_t / 3600)
    day_scale = torch.sin(delta_t / (24 * 3600))
    week_scale = torch.sin(delta_t / (7 * 24 * 3600))
    # Burst signal: edges sharing a user decay exponentially with a one-day time constant.
    same_user = (user_i == user_j).float()
    burst_feature = same_user * torch.exp(-delta_t / (24 * 3600))
    return torch.stack([hour_scale, day_scale, week_scale, burst_feature], dim=-1)
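
# Illustrative check (not called by the pipeline): two same-user edges observed
# one hour and one day apart; inputs are shape (E,), output is (E, 4).
#   t_src = torch.tensor([0.0, 0.0])
#   t_tgt = torch.tensor([3600.0, 86400.0])
#   users = torch.tensor([5, 5])
#   create_temporal_edge_features(t_src, t_tgt, users, users)
#   # burst column: exp(-1/24) ~ 0.96 for the 1h gap, exp(-1) ~ 0.37 for the 1d gap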


class CustomMultiheadAttention(nn.Module):
    """Multi-head attention that accepts an additive (batch, seq, seq) bias."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.scale = self.head_dim ** -0.5

    def forward(self, query, key, value, attn_bias=None):
        batch_size, seq_len, embed_dim = query.size()
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
        # Split heads: (batch, num_heads, seq, head_dim).
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if attn_bias is not None:
            # Broadcast the (batch, seq, seq) bias across heads.
            scores = scores + attn_bias.unsqueeze(1)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        out = self.out_proj(out)
        return out, attn
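
# Shape sketch for the bias-aware attention above:
#   attn = CustomMultiheadAttention(embed_dim=64, num_heads=4)
#   x = torch.randn(1, 10, 64)                    # (batch, seq, embed)
#   bias = torch.zeros(1, 10, 10)                 # additive (batch, seq, seq) bias
#   out, weights = attn(x, x, x, attn_bias=bias)  # out: (1, 10, 64), weights: (1, 4, 10, 10)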


class HeteroGraphormer(nn.Module):
    """Graphormer-style encoder over a (user, business, review) heterogeneous graph."""

    def __init__(self, hidden_dim, output_dim, num_heads=4, edge_dim=4):
        super().__init__()
        self.hidden_dim = hidden_dim
        # output_dim is kept for interface symmetry; the classifier emits one logit.

        # Per-type input projections: 14 user, 8 business, 16 review features.
        self.embed_dict = nn.ModuleDict({
            'user': nn.Linear(14, hidden_dim),
            'business': nn.Linear(8, hidden_dim),
            'review': nn.Linear(16, hidden_dim)
        })

        self.edge_proj = nn.Linear(edge_dim, hidden_dim)

        # Per-type GRU sweeps over each type's node ordering.
        self.gru_user = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.gru_business = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.gru_review = nn.GRU(hidden_dim, hidden_dim, batch_first=True)

        self.attention1 = CustomMultiheadAttention(hidden_dim, num_heads)
        self.attention2 = CustomMultiheadAttention(hidden_dim, num_heads)

        self.ffn1 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )
        self.ffn2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )

        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)
        self.norm4 = nn.LayerNorm(hidden_dim)

        self.centrality_proj = nn.Linear(1, hidden_dim)

        # Scores one (review, user, business) triple per review.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, 1)
        )

        self.dropout = nn.Dropout(0.1)

    def time_aware_aggregation(self, x, time_since, decay_rate=0.1):
        # Exponentially down-weight nodes whose last activity is long past.
        weights = torch.exp(-decay_rate * time_since.unsqueeze(-1))
        return x * weights
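
    # Example: with the default decay_rate=0.1, a node whose last activity was
    # 10 time units ago is scaled by exp(-1) ≈ 0.37; a just-active node keeps 1.0.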

    def forward(self, data, spatial_encoding, centrality_encoding, node_type_map, time_since_dict, edge_features_dict):
        x_dict = {}
        for node_type in data.x_dict:
            x = self.embed_dict[node_type](data[node_type].x)
            if node_type in time_since_dict:
                x = self.time_aware_aggregation(x, time_since_dict[node_type])
            x_dict[node_type] = x

        # Concatenate node types into one sequence: [users | businesses | reviews].
        x = torch.cat([x_dict['user'], x_dict['business'], x_dict['review']], dim=0)

        centrality = self.centrality_proj(centrality_encoding)
        x = x + centrality

        x = x.unsqueeze(0)

        num_users = data['user'].x.size(0)
        num_businesses = data['business'].x.size(0)
        x_user = x[:, :num_users, :]
        x_business = x[:, num_users:num_users + num_businesses, :]
        x_review = x[:, num_users + num_businesses:, :]

        x_user, _ = self.gru_user(x_user)
        x_business, _ = self.gru_business(x_business)
        x_review, _ = self.gru_review(x_review)

        x = torch.cat([x_user, x_business, x_review], dim=1)

        total_nodes = x.size(1)
        # Spatial bias: negated shortest-path distance; unreachable pairs carry
        # -inf, which softmax turns into zero attention.
        attn_bias = torch.zeros(1, total_nodes, total_nodes, device=x.device)
        attn_bias[0] = -spatial_encoding

        # Edge-feature bias. edge_index holds per-type (local) indices, so shift
        # them by each type's offset in the concatenated sequence before scattering.
        offsets = {'user': 0, 'business': num_users, 'review': num_users + num_businesses}
        for edge_type in edge_features_dict:
            src_type, _, tgt_type = edge_type
            edge_index = data[edge_type].edge_index
            edge_feats = self.edge_proj(edge_features_dict[edge_type])
            src = edge_index[0] + offsets[src_type]
            tgt = edge_index[1] + offsets[tgt_type]
            attn_bias[0].index_put_((src, tgt), edge_feats.sum(dim=-1), accumulate=True)

        residual = x
        x, _ = self.attention1(x, x, x, attn_bias=attn_bias)
        x = self.norm1(x + residual)
        x = self.dropout(x)

        residual = x
        x = self.ffn1(x)
        x = self.norm2(x + residual)
        x = self.dropout(x)

        residual = x
        x, _ = self.attention2(x, x, x, attn_bias=attn_bias)
        x = self.norm3(x + residual)
        x = self.dropout(x)

        residual = x
        x = self.ffn2(x)
        x = self.norm4(x + residual)
        x = self.dropout(x)

        x = x.squeeze(0)

        h_user = x[:num_users]
        h_business = x[num_users:num_users + num_businesses]
        h_review = x[num_users + num_businesses:]

        # One prediction per review: gather the writing user, the reviewed
        # business, and the review itself, then classify their concatenation.
        user_indices = data['user', 'writes', 'review'].edge_index[0]
        business_indices = data['review', 'about', 'business'].edge_index[1]
        review_indices = data['user', 'writes', 'review'].edge_index[1]

        h_user_mapped = h_user[user_indices]
        h_business_mapped = h_business[business_indices]
        h_review_mapped = h_review[review_indices]

        combined = torch.cat([h_review_mapped, h_user_mapped, h_business_mapped], dim=-1)

        logits = self.classifier(combined)
        # Sigmoid here pairs with BCELoss in the trainer.
        return torch.sigmoid(logits)


class GraphformerModel:
    """Trains and evaluates the HeteroGraphormer on a review dataframe."""

    def __init__(self, df, output_path, epochs, test_size=0.3):
        self.df_whole = df
        self.output_path = Path(output_path) / "GraphformerModel"
        self.epochs = epochs
        self.df, self.test_df = train_test_split(self.df_whole, test_size=test_size, random_state=42)

        torch.manual_seed(42)
        np.random.seed(42)

        self.output_path.mkdir(parents=True, exist_ok=True)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = HeteroGraphormer(hidden_dim=64, output_dim=1, edge_dim=4).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.005)
        # BCELoss pairs with the sigmoid the model applies to its logits.
        self.criterion = nn.BCELoss()

    def compute_graph_encodings(self, data):
        # Build one directed homogeneous graph over all nodes, ordered
        # [users | businesses | reviews] to match the model's concatenation.
        G = nx.DiGraph()
        node_offset = 0
        node_type_map = {}

        for node_type in ['user', 'business', 'review']:
            num_nodes = data[node_type].x.size(0)
            for i in range(num_nodes):
                G.add_node(node_offset + i)
                node_type_map[node_offset + i] = node_type
            node_offset += num_nodes

        edge_types = [('user', 'writes', 'review'), ('review', 'about', 'business')]
        for src_type, rel, tgt_type in edge_types:
            edge_index = data[src_type, rel, tgt_type].edge_index
            src_nodes = edge_index[0].tolist()
            tgt_nodes = edge_index[1].tolist()
            src_offset = 0 if src_type == 'user' else (self.num_users if src_type == 'business' else self.num_users + self.num_businesses)
            tgt_offset = 0 if tgt_type == 'user' else (self.num_users if tgt_type == 'business' else self.num_users + self.num_businesses)
            for src, tgt in zip(src_nodes, tgt_nodes):
                G.add_edge(src + src_offset, tgt + tgt_offset)

        # Spatial encoding: directed shortest-path length, inf where unreachable.
        # One all-pairs BFS pass instead of a per-pair has_path/shortest_path query.
        num_nodes = G.number_of_nodes()
        spatial_encoding = torch.full((num_nodes, num_nodes), float('inf'), device=self.device)
        for src, lengths in nx.all_pairs_shortest_path_length(G):
            for tgt, dist in lengths.items():
                spatial_encoding[src, tgt] = dist

        # Centrality encoding: total (in + out) degree per node.
        centrality_encoding = torch.tensor([G.degree(i) for i in range(num_nodes)], dtype=torch.float, device=self.device).view(-1, 1)

        return spatial_encoding, centrality_encoding, node_type_map
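
    # Sketch of the encodings on a single chain u -> r -> b:
    #   spatial_encoding[u, b] == 2 following edge direction, inf against it
    #   (the graph is directed); forward() negates this into the attention bias,
    #   so unreachable pairs attend with weight zero.
    #   centrality_encoding is total degree, e.g. 2 for the review node.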

    def compute_metrics(self, y_true, y_pred, y_prob, prefix=""):
        # Scalar metrics plus confusion matrix and per-class report; keys are
        # prefixed so train/test results can share one dict namespace.
        metrics = {}
        metrics[f"{prefix}accuracy"] = accuracy_score(y_true, y_pred)
        metrics[f"{prefix}precision"] = precision_score(y_true, y_pred, zero_division=0)
        metrics[f"{prefix}recall"] = recall_score(y_true, y_pred, zero_division=0)
        metrics[f"{prefix}f1"] = f1_score(y_true, y_pred, zero_division=0)
        metrics[f"{prefix}auc_roc"] = roc_auc_score(y_true, y_prob)
        metrics[f"{prefix}conf_matrix"] = confusion_matrix(y_true, y_pred)
        metrics[f"{prefix}class_report"] = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
        return metrics

    def run_model(self):
        # Assumed column layout after dropping ids/label: 14 user features,
        # then 8 business features, then 16 review features (see the lists below).
        features = torch.tensor(self.df.drop(columns=['user_id', 'review_id', 'business_id', 'fake']).values, dtype=torch.float, device=self.device)
        y = torch.tensor(self.df['fake'].values, dtype=torch.float, device=self.device)
        time_since_user = torch.tensor(self.df['time_since_last_review_user'].values, dtype=torch.float, device=self.device)
        time_since_business = torch.tensor(self.df['time_since_last_review_business'].values, dtype=torch.float, device=self.device)
        num_rows = len(self.df)

        graph = HeteroData()

        self.num_users = self.df['user_id'].nunique()
        self.num_businesses = self.df['business_id'].nunique()

        # Map raw ids to contiguous node indices; reviews are one node per row.
        user_indices = torch.tensor(self.df['user_id'].map({uid: i for i, uid in enumerate(self.df['user_id'].unique())}).values, dtype=torch.long, device=self.device)
        business_indices = torch.tensor(self.df['business_id'].map({bid: i for i, bid in enumerate(self.df['business_id'].unique())}).values, dtype=torch.long, device=self.device)
        review_indices = torch.arange(num_rows, dtype=torch.long, device=self.device)

        user_feats = torch.zeros(self.num_users, 14, device=self.device)
        business_feats = torch.zeros(self.num_businesses, 8, device=self.device)
        # These lists document the expected column order; they are not used to
        # index `features` directly.
        user_cols = ['hours', 'user_review_count', 'elite', 'friends', 'fans', 'average_stars',
                     'time_since_last_review_user', 'user_account_age', 'user_degree',
                     'user_review_burst_count', 'review_like_ratio', 'latest_checkin_hours',
                     'user_useful_funny_cool', 'rating_variance_user']
        business_cols = ['latitude', 'longitude', 'business_stars', 'business_review_count',
                         'time_since_last_review_business', 'business_degree',
                         'business_review_burst_count', 'rating_deviation_from_business_average']
        review_cols = ['review_stars', 'tip_compliment_count', 'tip_count', 'average_time_between_reviews',
                       'temporal_similarity', 'pronoun_density', 'avg_sentence_length',
                       'excessive_punctuation_count', 'sentiment_polarity', 'good_severity',
                       'bad_severity', 'code_switching_flag', 'grammar_error_score',
                       'repetitive_words_count', 'similarity_to_other_reviews', 'review_useful_funny_cool']

        # Aggregate row-level slices into per-node features (sum over each
        # user's / business's rows); review nodes take their columns directly.
        user_feats.index_add_(0, user_indices, features[:, :14])
        business_feats.index_add_(0, business_indices, features[:, 14:22])
        review_feats = features[:, 22:38]

        graph['user'].x = user_feats
        graph['business'].x = business_feats
        graph['review'].x = review_feats
        graph['review'].y = y

        # Connect each row's user, review, and business.
        graph['user', 'writes', 'review'].edge_index = torch.stack([user_indices, review_indices], dim=0)
        graph['review', 'about', 'business'].edge_index = torch.stack([review_indices, business_indices], dim=0)
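
        # Each row i thus forms the path:
        #   user user_indices[i] --writes--> review i --about--> business business_indices[i]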

        edge_features_dict = {}
        user_writes_edge = graph['user', 'writes', 'review'].edge_index
        review_about_edge = graph['review', 'about', 'business'].edge_index

        # Endpoint "clocks" for the temporal edge features: one reasonable choice
        # is each user's / business's most recent activity (group-wise min of the
        # time-since-last-review columns, via PyTorch >= 1.12 scatter_reduce)
        # against the row-level value on the review endpoint.
        user_min_time = torch.zeros(self.num_users, device=self.device)
        user_min_time.scatter_reduce_(0, user_indices, time_since_user, reduce='amin', include_self=False)
        business_min_time = torch.zeros(self.num_businesses, device=self.device)
        business_min_time.scatter_reduce_(0, business_indices, time_since_business, reduce='amin', include_self=False)

        src_users = user_writes_edge[0]
        tgt_reviews = user_writes_edge[1]
        edge_features_dict[('user', 'writes', 'review')] = create_temporal_edge_features(
            user_min_time[src_users], time_since_user[tgt_reviews], src_users, src_users
        )

        src_reviews = review_about_edge[0]
        tgt_businesses = review_about_edge[1]
        edge_features_dict[('review', 'about', 'business')] = create_temporal_edge_features(
            time_since_business[src_reviews], business_min_time[tgt_businesses],
            torch.zeros_like(src_reviews), torch.zeros_like(src_reviews)
        )

        # Node-level recency for time-aware aggregation, ordered like the user nodes.
        user_time_since = self.df.groupby('user_id')['time_since_last_review_user'].min().reindex(
            self.df['user_id'].unique(), fill_value=0).values
        time_since_dict = {
            'user': torch.tensor(user_time_since, dtype=torch.float, device=self.device)
        }

        spatial_encoding, centrality_encoding, node_type_map = self.compute_graph_encodings(graph)

        self.model.train()
        train_metrics_history = []
        for epoch in range(self.epochs):
            self.optimizer.zero_grad()
            out = self.model(graph, spatial_encoding, centrality_encoding, node_type_map, time_since_dict, edge_features_dict)
            loss = self.criterion(out.squeeze(), y)
            loss.backward()
            self.optimizer.step()

            pred_labels = (out.squeeze() > 0.5).float()

            probs = out.squeeze().detach().cpu().numpy()
            train_metrics = self.compute_metrics(y.cpu().numpy(), pred_labels.cpu().numpy(), probs, prefix="train_")
            train_metrics['loss'] = loss.item()
            train_metrics_history.append(train_metrics)

            if epoch % 10 == 0:
                logger.info(f"Epoch {epoch}, Loss: {loss.item():.4f}, Accuracy: {train_metrics['train_accuracy']:.4f}, F1: {train_metrics['train_f1']:.4f}")

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_save_path = self.output_path / "model_GraphformerModel_latest.pth"
        torch.save(self.model.state_dict(), model_save_path)
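
        # To reload the checkpoint later (sketch):
        #   model = HeteroGraphormer(hidden_dim=64, output_dim=1, edge_dim=4)
        #   model.load_state_dict(torch.load(model_save_path))
        #   model.eval()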

        if self.test_df is not None:
            # Extend the training graph with test rows (transductive evaluation):
            # unseen users/businesses get new nodes, every test review is appended.
            test_features = torch.tensor(self.test_df.drop(columns=['user_id', 'review_id', 'business_id', 'fake']).values, dtype=torch.float, device=self.device)
            test_y = torch.tensor(self.test_df['fake'].values, dtype=torch.float, device=self.device)
            test_time_since_user = torch.tensor(self.test_df['time_since_last_review_user'].values, dtype=torch.float, device=self.device)
            test_time_since_business = torch.tensor(self.test_df['time_since_last_review_business'].values, dtype=torch.float, device=self.device)
            num_test_rows = len(self.test_df)

            new_user_unique = self.test_df['user_id'].unique()
            new_business_unique = self.test_df['business_id'].unique()
            # Reuse training node indices for known ids; append new ones after them.
            existing_user_ids = list(self.df['user_id'].unique())
            user_mapping = {uid: i for i, uid in enumerate(existing_user_ids)}
            total_users = self.num_users
            for uid in new_user_unique:
                if uid not in user_mapping:
                    user_mapping[uid] = total_users
                    total_users += 1

            existing_business_ids = list(self.df['business_id'].unique())
            business_mapping = {bid: i for i, bid in enumerate(existing_business_ids)}
            total_businesses = self.num_businesses
            for bid in new_business_unique:
                if bid not in business_mapping:
                    business_mapping[bid] = total_businesses
                    total_businesses += 1

            new_user_indices = torch.tensor([user_mapping[uid] for uid in self.test_df['user_id']], dtype=torch.long, device=self.device)
            new_business_indices = torch.tensor([business_mapping[bid] for bid in self.test_df['business_id']], dtype=torch.long, device=self.device)
            new_review_indices = torch.arange(num_rows, num_rows + num_test_rows, device=self.device)
            # Grow the node feature matrices for previously unseen users/businesses.
            if total_users > self.num_users:
                additional_user_feats = torch.zeros(total_users - self.num_users, 14, device=self.device)
                graph['user'].x = torch.cat([graph['user'].x, additional_user_feats], dim=0)
            if total_businesses > self.num_businesses:
                additional_business_feats = torch.zeros(total_businesses - self.num_businesses, 8, device=self.device)
                graph['business'].x = torch.cat([graph['business'].x, additional_business_feats], dim=0)

            # Accumulate test rows into node features, mirroring the training pass
            # (all indices are in range after the extension above).
            graph['user'].x.index_add_(0, new_user_indices, test_features[:, :14])
            graph['business'].x.index_add_(0, new_business_indices, test_features[:, 14:22])
            graph['review'].x = torch.cat([graph['review'].x, test_features[:, 22:38]], dim=0)
            graph['review'].y = torch.cat([graph['review'].y, test_y], dim=0)
            graph['user', 'writes', 'review'].edge_index = torch.cat([
                graph['user', 'writes', 'review'].edge_index,
                torch.stack([new_user_indices, new_review_indices], dim=0)], dim=1)
            graph['review', 'about', 'business'].edge_index = torch.cat([
                graph['review', 'about', 'business'].edge_index,
                torch.stack([new_review_indices, new_business_indices], dim=0)], dim=1)

            all_time_since_user = torch.cat([time_since_user, test_time_since_user])
            all_time_since_business = torch.cat([time_since_business, test_time_since_business])
            all_user_indices = torch.cat([user_indices, new_user_indices])
            all_business_indices = torch.cat([business_indices, new_business_indices])

            user_writes_edge = graph['user', 'writes', 'review'].edge_index
            review_about_edge = graph['review', 'about', 'business'].edge_index

            # Rebuild temporal edge features over train + test rows, with the
            # same group-wise activity clocks as in training.
            all_user_min_time = torch.zeros(total_users, device=self.device)
            all_user_min_time.scatter_reduce_(0, all_user_indices, all_time_since_user, reduce='amin', include_self=False)
            all_business_min_time = torch.zeros(total_businesses, device=self.device)
            all_business_min_time.scatter_reduce_(0, all_business_indices, all_time_since_business, reduce='amin', include_self=False)

            edge_features_dict[('user', 'writes', 'review')] = create_temporal_edge_features(
                all_user_min_time[user_writes_edge[0]], all_time_since_user[user_writes_edge[1]],
                user_writes_edge[0], user_writes_edge[0]
            )
            edge_features_dict[('review', 'about', 'business')] = create_temporal_edge_features(
                all_time_since_business[review_about_edge[0]], all_business_min_time[review_about_edge[1]],
                torch.zeros_like(review_about_edge[0]), torch.zeros_like(review_about_edge[0])
            )

            self.num_users = total_users
            self.num_businesses = total_businesses

            # Node-level recency aligned with user_mapping order
            # (training users first, then newly added test users).
            combined_df = pd.concat([self.df, self.test_df])
            all_user_time = combined_df.groupby('user_id')['time_since_last_review_user'].min().reindex(
                pd.Index(list(user_mapping.keys())), fill_value=0).values
            time_since_dict['user'] = torch.tensor(all_user_time, dtype=torch.float, device=self.device)

            spatial_encoding, centrality_encoding, node_type_map = self.compute_graph_encodings(graph)

            self.model.eval()
            with torch.no_grad():
                out = self.model(graph, spatial_encoding, centrality_encoding, node_type_map, time_since_dict, edge_features_dict)
                pred_labels = (out.squeeze() > 0.5).float()
                probs = out.squeeze().cpu().numpy()
                # Predictions are ordered like the review nodes: training rows first.
                test_metrics = self.compute_metrics(graph['review'].y[-num_test_rows:].cpu().numpy(), pred_labels[-num_test_rows:].cpu().numpy(), probs[-num_test_rows:], prefix="test_")
                train_metrics = self.compute_metrics(y.cpu().numpy(), pred_labels[:num_rows].cpu().numpy(), probs[:num_rows], prefix="train_")
                logger.info(f"Test Accuracy: {test_metrics['test_accuracy']:.4f}, F1: {test_metrics['test_f1']:.4f}, AUC-ROC: {test_metrics['test_auc_roc']:.4f}")

            metrics_file = self.output_path / f"metrics_{timestamp}.txt"
            with open(metrics_file, 'w') as f:
                f.write("Training Metrics (Final Epoch):\n")
                for k, v in train_metrics.items():
                    f.write(f"{k}: {v}\n")
                f.write("\nTest Metrics:\n")
                for k, v in test_metrics.items():
                    f.write(f"{k}: {v}\n")

            # Training-history and evaluation plots.
            plt.figure(figsize=(12, 8))
            plt.plot([m['loss'] for m in train_metrics_history], label='Training Loss')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.title('Training Loss Curve')
            plt.legend()
            plt.grid(True)
            plt.savefig(self.output_path / f"loss_curve_{timestamp}.png")
            plt.close()

            plt.figure(figsize=(12, 8))
            plt.plot([m['train_accuracy'] for m in train_metrics_history], label='Training Accuracy')
            plt.xlabel('Epoch')
            plt.ylabel('Accuracy')
            plt.title('Training Accuracy Curve')
            plt.legend()
            plt.grid(True)
            plt.savefig(self.output_path / f"accuracy_curve_{timestamp}.png")
            plt.close()

            plt.figure(figsize=(12, 8))
            plt.plot([m['train_precision'] for m in train_metrics_history], label='Training Precision')
            plt.plot([m['train_recall'] for m in train_metrics_history], label='Training Recall')
            plt.plot([m['train_f1'] for m in train_metrics_history], label='Training F1-Score')
            plt.xlabel('Epoch')
            plt.ylabel('Score')
            plt.title('Training Precision, Recall, and F1-Score Curves')
            plt.legend()
            plt.grid(True)
            plt.savefig(self.output_path / f"prf1_curves_{timestamp}.png")
            plt.close()

            plt.figure(figsize=(12, 8))
            plt.plot([m['train_auc_roc'] for m in train_metrics_history], label='Training AUC-ROC')
            plt.xlabel('Epoch')
            plt.ylabel('AUC-ROC')
            plt.title('Training AUC-ROC Curve')
            plt.legend()
            plt.grid(True)
            plt.savefig(self.output_path / f"auc_roc_curve_train_{timestamp}.png")
            plt.close()

            plt.figure(figsize=(8, 6))
            sns.heatmap(test_metrics['test_conf_matrix'], annot=True, fmt='d', cmap='Blues', cbar=False)
            plt.xlabel('Predicted')
            plt.ylabel('True')
            plt.title('Test Confusion Matrix')
            plt.savefig(self.output_path / f"confusion_matrix_test_{timestamp}.png")
            plt.close()

            fpr, tpr, _ = roc_curve(graph['review'].y[-num_test_rows:].cpu().numpy(), probs[-num_test_rows:])
            plt.figure(figsize=(10, 6))
            plt.plot(fpr, tpr, label=f'Test ROC Curve (AUC = {test_metrics["test_auc_roc"]:.4f})')
            plt.plot([0, 1], [0, 1], 'k--', label='Random Guess')
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title('Test ROC Curve')
            plt.legend()
            plt.grid(True)
            plt.savefig(self.output_path / f"roc_curve_test_{timestamp}.png")
            plt.close()

            plt.figure(figsize=(8, 6))
            sns.heatmap(train_metrics['train_conf_matrix'], annot=True, fmt='d', cmap='Blues', cbar=False)
            plt.xlabel('Predicted')
            plt.ylabel('True')
            plt.title('Training Confusion Matrix (Final Epoch)')
            plt.savefig(self.output_path / f"confusion_matrix_train_{timestamp}.png")
            plt.close()

            fpr_train, tpr_train, _ = roc_curve(graph['review'].y[:num_rows].cpu().numpy(), probs[:num_rows])
            plt.figure(figsize=(10, 6))
            plt.plot(fpr_train, tpr_train, label=f'Training ROC Curve (AUC = {train_metrics["train_auc_roc"]:.4f})')
            plt.plot([0, 1], [0, 1], 'k--', label='Random Guess')
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title('Training ROC Curve (Final Epoch)')
            plt.legend()
            plt.grid(True)
            plt.savefig(self.output_path / f"roc_curve_train_{timestamp}.png")
            plt.close()

        logger.info(f"All metrics, plots, and model saved to {self.output_path}")
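

if __name__ == "__main__":
    # Illustrative wiring only (a sketch, not part of the pipeline): the input
    # dataframe is assumed to contain user_id, review_id, business_id, fake,
    # plus the feature columns listed in run_model; the CSV path is hypothetical.
    # df = pd.read_csv("reviews_with_features.csv")
    # GraphformerModel(df, output_path="./output", epochs=100).run_model()
    pass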