import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, BatchNorm
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import os
import json
from huggingface_hub import hf_hub_download

# ==========================================
# 1. MODEL DEFINITION
# ==========================================
class SAGE(nn.Module):
    """Two-layer GraphSAGE node classifier.

    Each layer is SAGEConv -> BatchNorm -> ReLU -> Dropout; a final linear
    head maps the hidden representation to ``out_dim`` logits per node.

    Args:
        in_dim: Number of input features per node.
        h: Hidden size of both SAGE layers.
        out_dim: Number of output logits per node.
        p_drop: Dropout probability applied after each layer.
    """

    def __init__(self, in_dim, h=128, out_dim=2, p_drop=0.3):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, h, bias=True)
        self.bn1 = BatchNorm(h)
        self.conv2 = SAGEConv(h, h, bias=True)
        self.bn2 = BatchNorm(h)
        self.head = nn.Linear(h, out_dim)
        self.drop = nn.Dropout(p_drop)

    def forward(self, x, edge_index):
        """Return per-node logits (last dimension ``out_dim``)."""
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.drop(x)
        x = self.conv2(x, edge_index)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.drop(x)
        return self.head(x)


# ==========================================
# 2. RESOURCE MANAGEMENT
# ==========================================
REPO_ID = "uyen1109/eth-fraud-gnn-uyenuyen-v3"
TOKEN = os.getenv("HF_TOKEN")

# Shared state filled in once by load_resources() and read by the UI handlers.
GLOBAL_DATA = {
    "model": None,
    "df_scores": pd.DataFrame(),
    "df_edges": pd.DataFrame(),
    "feature_cols": [],
    "status": "Initializing...",
}

# Feature names used for the Predict inputs when the model (and therefore the
# real feature list) could not be loaded, so the UI still renders.
_FALLBACK_FEATURE_COLS = [
    'out_deg', 'in_deg', 'eth_out_sum', 'eth_in_sum', 'unique_dst_cnt',
    'unique_src_cnt', 'first_seen_ts', 'last_seen_ts', 'pr', 'clust_coef',
    'betw', 'feat_11', 'feat_12', 'feat_13', 'feat_14',
]


def smart_load_file(filename):
    """Download ``filename`` from the Hub repo and return its local path.

    The repo root is tried first, then the ``hf_export/`` subfolder. At each
    location the download is attempted with ``HF_TOKEN`` (private repo / LFS)
    and, when a token is configured, again anonymously (public repo).

    Args:
        filename: File name relative to the repo root or ``hf_export/``.

    Returns:
        The local cache path, or ``None`` if every attempt failed (the
        collected errors are printed).
    """
    if TOKEN:
        # token=False really sends no credentials; token=None would make
        # huggingface_hub fall back to the same HF_TOKEN / cached login again.
        attempts = [("Token", TOKEN), ("No-Token", False)]
    else:
        # Without HF_TOKEN a second attempt would be an identical request.
        attempts = [("No-Token", None)]

    errs = []
    for repo_path in (filename, f"hf_export/{filename}"):
        for label, token in attempts:
            try:
                return hf_hub_download(repo_id=REPO_ID, filename=repo_path, token=token)
            except Exception as e:  # network / auth / 404 -> try the next option
                errs.append(f"{label} fail {repo_path}: {e}")
    print(f"⚠️ Failed to load {filename}. Details: {errs}")
    return None


def _normalize_addresses(series):
    """Return ``series`` as lower-cased, whitespace-stripped strings."""
    return series.astype(str).str.lower().str.strip()


def _load_scores(logs):
    """Load the node score table into GLOBAL_DATA["df_scores"], indexed by address."""
    path = smart_load_file("node_scores_with_labels.csv")
    if not path:
        logs.append("❌ 'node_scores_with_labels.csv' download failed.")
        return
    try:
        df = pd.read_csv(path)
        # Use an 'address' column (case-insensitive) if present, else the first column.
        cols_lower = [str(c).lower() for c in df.columns]
        if "address" in cols_lower:
            addr_col = df.columns[cols_lower.index("address")]
        else:
            addr_col = df.columns[0]
        df[addr_col] = _normalize_addresses(df[addr_col])
        df.set_index(addr_col, inplace=True)
        GLOBAL_DATA["df_scores"] = df
        logs.append(f"✅ Loaded Scores: {len(df)} rows.")
    except Exception as e:
        logs.append(f"❌ Error parsing scores csv: {e}")


def _load_edges(logs):
    """Load the edge list (src, dst, edge_type) into GLOBAL_DATA["df_edges"]."""
    path = smart_load_file("edges_all.csv")
    if not path:
        logs.append("⚠️ 'edges_all.csv' download failed.")
        return
    try:
        edges = pd.read_csv(path, usecols=["src", "dst", "edge_type"])
        # Normalize the same way as the score index so lookups can match.
        edges["src"] = _normalize_addresses(edges["src"])
        edges["dst"] = _normalize_addresses(edges["dst"])
        GLOBAL_DATA["df_edges"] = edges
        logs.append(f"✅ Loaded Edges: {len(edges)} rows.")
    except Exception as e:
        logs.append(f"⚠️ Edge parsing error: {e}")


def _resolve_feature_cols(detected_dim, logs):
    """Return exactly ``detected_dim`` feature names, from feature_columns.json if usable."""
    dummy = [f"Feature_{i}" for i in range(detected_dim)]
    cols_path = smart_load_file("feature_columns.json")
    if not cols_path:
        logs.append("⚠️ Using Dummy Feature Names (json missing)")
        return dummy
    try:
        with open(cols_path, "r", encoding="utf-8") as f:
            cols = json.load(f)
    except (OSError, ValueError) as e:  # JSONDecodeError / UnicodeDecodeError are ValueErrors
        logs.append(f"⚠️ Using Dummy Feature Names (json unreadable: {e})")
        return dummy
    if not isinstance(cols, list):
        logs.append("⚠️ Using Dummy Feature Names (json is not a list)")
        return dummy

    cols = [str(c) for c in cols]
    # Truncate or pad so the UI offers exactly one input per model feature.
    if len(cols) >= detected_dim:
        return cols[:detected_dim]
    return cols + [f"Feat_{i}" for i in range(len(cols), detected_dim)]


def _load_model(logs):
    """Load the SAGE checkpoint into GLOBAL_DATA["model"] and set feature names."""
    model_path = smart_load_file("pytorch_model.bin")
    if not model_path:
        logs.append("❌ 'pytorch_model.bin' NOT FOUND. Please upload it to Repo Root.")
        GLOBAL_DATA["feature_cols"] = list(_FALLBACK_FEATURE_COLS)
        return
    try:
        # The checkpoint comes from a remote repo; a plain state_dict needs no
        # arbitrary-object unpickling, so restrict loading to tensors.
        state_dict = torch.load(
            model_path, map_location=torch.device('cpu'), weights_only=True
        )
        # nn.Linear weights are (out_features, in_features), so conv1's input
        # projection gives both the input dim and the hidden size.
        conv1_weight = state_dict['conv1.lin_l.weight']
        detected_dim = conv1_weight.shape[1]
        hidden_dim = conv1_weight.shape[0]
        out_dim = state_dict['head.weight'].shape[0]
        model = SAGE(in_dim=detected_dim, h=hidden_dim, out_dim=out_dim, p_drop=0.3)
        model.load_state_dict(state_dict)
        model.eval()
    except Exception as e:
        logs.append(f"❌ Model Init Error: {e}")
        GLOBAL_DATA["feature_cols"] = list(_FALLBACK_FEATURE_COLS)
        return

    GLOBAL_DATA["model"] = model
    logs.append(f"✅ Model Loaded (Input Dim: {detected_dim})")
    # Feature-name problems are reported separately and never unload the model.
    GLOBAL_DATA["feature_cols"] = _resolve_feature_cols(detected_dim, logs)


def load_resources():
    """Populate GLOBAL_DATA with scores, edges, the model and feature names.

    Failures are recorded as lines in GLOBAL_DATA["status"] (also printed)
    instead of raised, so the app can start with partial resources.
    """
    print("⏳ Starting Resource Loading...")
    logs = []
    _load_scores(logs)
    _load_edges(logs)
    _load_model(logs)
    GLOBAL_DATA["status"] = "\n".join(logs)
    print(GLOBAL_DATA["status"])


# Runs at import time: the Predict tab builds one input per entry of
# GLOBAL_DATA["feature_cols"], so it must be populated before the UI is built.
load_resources()
# ==========================================
# 3. REQUEST HANDLING LOGIC
# ==========================================
def draw_graph(address):
    """Draw up to 20 edges touching ``address`` and return the Figure.

    Args:
        address: Normalized address to center the plot on.

    Returns:
        A matplotlib Figure, or ``None`` when no edge data is loaded or the
        address has no edges.
    """
    df = GLOBAL_DATA["df_edges"]
    if df.empty:
        return None
    subset = df[(df["src"] == address) | (df["dst"] == address)].head(20)
    if subset.empty:
        return None

    G = nx.from_pandas_edgelist(
        subset, "src", "dst", edge_attr="edge_type", create_using=nx.DiGraph()
    )
    # Draw on an explicit Figure/Axes instead of pyplot's implicit "current"
    # figure, which is global state shared by concurrent requests.
    fig, ax = plt.subplots(figsize=(6, 6))
    pos = nx.spring_layout(G, k=0.9, seed=42)
    node_colors = ["#FF4500" if n == address else "#1E90FF" for n in G.nodes()]
    nx.draw_networkx_nodes(G, pos, ax=ax, node_color=node_colors, node_size=200, alpha=0.9)
    nx.draw_networkx_edges(G, pos, ax=ax, alpha=0.3, arrowstyle='->')
    nx.draw_networkx_labels(G, pos, ax=ax, labels={n: n[:4] for n in G.nodes()}, font_size=8)
    ax.set_title(f"Ego Graph: {address[:6]}...")
    ax.axis('off')
    # Unregister from pyplot so figures don't pile up across requests; the
    # Figure object itself can still be rendered by gr.Plot.
    plt.close(fig)
    return fig


def _address_variants(addr):
    """Return candidate keys for ``addr``: as given, then with/without the '0x' prefix."""
    if addr.startswith("0x"):
        return [addr, addr[2:]]
    return [addr, "0x" + addr]


def lookup_handler(address):
    """Look up the stored fraud score for ``address``.

    Returns:
        A ``(markdown, figure_or_None)`` pair for the Lookup tab outputs.
    """
    raw_addr = str(address).strip().lower() if address else ""
    if not raw_addr:
        return "Please enter an address.", None

    df = GLOBAL_DATA["df_scores"]
    # The index may store addresses with or without '0x'; accept either form.
    key = None
    if not df.empty:
        key = next((k for k in _address_variants(raw_addr) if k in df.index), None)

    if key is None:
        return (
            f"### ❌ Not Found\nAddress `{raw_addr}` not in database.\nStatus Logs:\n{GLOBAL_DATA['status']}",
            None,
        )

    found = df.loc[key]
    if isinstance(found, pd.DataFrame):  # duplicated address rows: use the first
        found = found.iloc[0]

    # Prefer 'prob_criminal', then 'susp'. A missing or NaN score is reported as
    # unknown rather than silently shown as a benign 0.0.
    score_col = next((c for c in ("prob_criminal", "susp") if c in found.index), None)
    score = float(found[score_col]) if score_col is not None else float("nan")
    if pd.isna(score):
        score_txt, status = "N/A", "UNKNOWN ⚪"
    else:
        score_txt = f"{score:.4f}"
        status = 'CRITICAL 🔴' if score > 0.5 else 'BENIGN 🟢'

    # Edge data may use the other prefix form; try the matched key first.
    fig = None
    for candidate in _address_variants(key):
        fig = draw_graph(candidate)
        if fig is not None:
            break

    return (
        f"### ✅ Found\n**Score:** {score_txt}\n**Status:** {status}",
        fig,
    )


def predict_handler(*features):
    """Score a single node from manually entered feature values.

    Args:
        *features: One value per entry of GLOBAL_DATA["feature_cols"], in order.

    Returns:
        Markdown with the class-1 (fraud) probability, or an error message.
    """
    model = GLOBAL_DATA["model"]
    if model is None:
        return "❌ Model Error: model not loaded (pytorch_model.bin missing or invalid).\nPlease check 'System Status' above."
    # A cleared gr.Number arrives as None; float(None) would give a cryptic TypeError.
    if any(f is None for f in features):
        return "Error: please fill in every feature value."
    try:
        x = torch.tensor([[float(f) for f in features]], dtype=torch.float)
        # No edges: the node has no neighbours to aggregate from.
        edge_index = torch.empty((2, 0), dtype=torch.long)
        with torch.no_grad():
            probs = torch.softmax(model(x, edge_index), dim=1)
        prob = probs[0][1].item()
        return f"### Result\n**Fraud Probability:** {prob*100:.2f}%"
    except Exception as e:
        return f"Error: {e}"


# ==========================================
# 4. UI SETUP
# ==========================================
with gr.Blocks(title="ETH Fraud GNN") as demo:
    gr.Markdown("# 🕵️‍♀️ Ethereum Fraud Inspector")

    with gr.Accordion("System Status (Click to Debug)", open=False):
        # Gradio evaluates a callable value when the app loads.
        gr.Markdown(lambda: GLOBAL_DATA["status"])

    with gr.Tabs():
        with gr.TabItem("🔍 Lookup"):
            with gr.Row():
                inp = gr.Textbox(label="Address")
                btn = gr.Button("Search", variant="primary")
            with gr.Row():
                out_txt = gr.Markdown()
                out_plt = gr.Plot()
            btn.click(lookup_handler, inputs=inp, outputs=[out_txt, out_plt])

        with gr.TabItem("🧠 Predict"):
            gr.Markdown("### Inductive Prediction (Simulated)")
            # One numeric input per loaded feature column. The order of
            # `inputs` must match feature_cols, because predict_handler
            # receives the values positionally.
            cols = GLOBAL_DATA["feature_cols"]
            inputs = []
            with gr.Row():
                # Spread inputs over two columns, alternating.
                c1, c2 = gr.Column(), gr.Column()
                for i, c in enumerate(cols):
                    with (c1 if i % 2 == 0 else c2):
                        inputs.append(gr.Number(label=c, value=0.0))
            btn2 = gr.Button("Predict", variant="primary")
            out2 = gr.Markdown()
            btn2.click(predict_handler, inputs=inputs, outputs=out2)

if __name__ == "__main__":
    demo.launch()