Update bert_explainer.py

bert_explainer.py  CHANGED  (+134 −18)

This commit reorganizes the module under numbered section comments and, alongside the existing text classifier, adds OCR preprocessing (`preprocess_for_easyocr`) and an image-analysis entry point (`analyze_image`).
Old version (lines removed by this commit are marked −; some deleted lines appear empty in the exported view):

```diff
@@ -1,18 +1,25 @@
-import torch
-from transformers import BertTokenizer, BertModel
-from AI_Model_architecture import BertLSTM_CNN_Classifier
-import re
 import os
 import requests
 import jieba
 
-#
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-#
 model_path = "/tmp/model.pth"
 model_url = "https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth"
-
 if not os.path.exists(model_path):
     print("📦 下載 model.pth 中...")
     response = requests.get(model_url)
@@ -23,31 +30,36 @@
     else:
         raise FileNotFoundError("❌ 無法下載 model.pth")
 
-#
 tokenizer = BertTokenizer.from_pretrained("ckiplab/bert-base-chinese")
 
-# ✅ Initialize the custom classification model
 model = BertLSTM_CNN_Classifier()
 model.load_state_dict(torch.load(model_path, map_location=device))
 model.to(device)
 model.eval()
 
-#
 bert_model = BertModel.from_pretrained("ckiplab/bert-base-chinese", output_attentions=True)
 bert_model.to(device)
 bert_model.eval()
 
-# ✅ Predict a single sentence
 
 def predict_single_sentence(text: str, max_len=256):
     text = re.sub(r"\s+", "", text)
     text = re.sub(r"[^\u4e00-\u9fffA-Za-z0-9。,!?:/.\-]", "", text)
 
-
     input_ids = encoded["input_ids"].to(device)
     attention_mask = encoded["attention_mask"].to(device)
     token_type_ids = encoded["token_type_ids"].to(device)
 
     with torch.no_grad():
         output = model(input_ids, attention_mask, token_type_ids)
         prob = output.item()
@@ -55,24 +67,32 @@
 
     return label, prob
 
-# ✅ Extract high-attention tokens and convert them into natural-language words
 
 def extract_attention_keywords(text, top_k=5):
     cleaned = re.sub(r"\s+", "", text)
-
     input_ids = encoded["input_ids"].to(device)
     attention_mask = encoded["attention_mask"].to(device)
 
     with torch.no_grad():
         outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
-        attentions = outputs.attentions
 
-
     tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
     top_indices = attn.topk(top_k).indices.tolist()
-
     top_tokens = [tokens[i] for i in top_indices if tokens[i] not in ["[CLS]", "[SEP]", "[PAD]"]]
 
     words = list(jieba.cut(text))
     suspicious = []
     for word in words:
@@ -83,11 +103,22 @@
             suspicious.append(word)
             break
 
     return suspicious[:top_k] if suspicious else top_tokens[:top_k]
 
-# ✅ Wrap the analysis function used by the main API
 
 def analyze_text(text: str):
     label, prob = predict_single_sentence(text)
     prob_percent = round(prob * 100, 2)
     status = "詐騙" if label == 1 else "正常"
@@ -98,3 +129,88 @@
         "confidence": prob_percent,
         "suspicious_keywords": suspicious or ["(模型未聚焦可疑詞)"]
     }
```
New version, section by section (collapsed diff context is marked with `# …` comments). The imports now pull in the OCR stack (NumPy, OpenCV, EasyOCR, PIL):

```python
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # tolerate duplicate OpenMP runtimes (a common issue when torch and cv2 coexist)

import io
import re
import requests
import torch
import jieba
import numpy as np
import cv2
import easyocr
from PIL import Image
from transformers import BertTokenizer, BertModel
from AI_Model_architecture import BertLSTM_CNN_Classifier
```
```python
# ─────────────────────────────────────────────────────────────────────────────
# 1. Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2. Download and load the custom classifier (BertLSTM_CNN_Classifier)
model_path = "/tmp/model.pth"
model_url = "https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth"
if not os.path.exists(model_path):
    print("📦 下載 model.pth 中...")
    response = requests.get(model_url)
    # … (lines 26–29 collapsed in the diff view: presumably the status check
    #     and writing the downloaded bytes to model_path)
    else:
        raise FileNotFoundError("❌ 無法下載 model.pth")

# 3. Initialize the tokenizer and the custom classifier
tokenizer = BertTokenizer.from_pretrained("ckiplab/bert-base-chinese")

model = BertLSTM_CNN_Classifier()
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

# 4. Initialize the vanilla BERT model (used for attention maps)
bert_model = BertModel.from_pretrained("ckiplab/bert-base-chinese", output_attentions=True)
bert_model.to(device)
bert_model.eval()
# ─────────────────────────────────────────────────────────────────────────────
```
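A quick smoke test for this setup (not part of the commit; it assumes, as `predict_single_sentence` below does, that the classifier takes the three BERT tensors and returns a single-probability tensor):

```python
# Hypothetical smoke test: push one padded batch through the loaded classifier.
enc = tokenizer("測試", return_tensors="pt", truncation=True,
                padding="max_length", max_length=256)
with torch.no_grad():
    out = model(enc["input_ids"].to(device),
                enc["attention_mask"].to(device),
                enc["token_type_ids"].to(device))
print(float(out.item()))  # expected: a value in [0, 1]
```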
```python
# ─────────────────────────────────────────────────────────────────────────────
# 5. Predict a single sentence
def predict_single_sentence(text: str, max_len=256):
    # 5.1. Light cleanup: strip whitespace; keep CJK, Latin letters, digits, and a few punctuation marks
    text = re.sub(r"\s+", "", text)
    text = re.sub(r"[^\u4e00-\u9fffA-Za-z0-9。,!?:/.\-]", "", text)

    # 5.2. Tokenize and convert to tensors
    encoded = tokenizer(text, return_tensors="pt", truncation=True,
                        padding="max_length", max_length=max_len)
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)
    token_type_ids = encoded["token_type_ids"].to(device)

    # 5.3. Model inference
    with torch.no_grad():
        output = model(input_ids, attention_mask, token_type_ids)
        prob = output.item()
    # … (line 66, which derives `label` from `prob`, is collapsed in the diff view)

    return label, prob
```
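A minimal usage sketch (the sample sentence is mine; the exact thresholding lives in the collapsed line 66):

```python
# Hypothetical call — the example sentence is illustrative only.
label, prob = predict_single_sentence("您已中獎,請點擊連結領取獎金")
print(label, prob)  # label: 1 means scam, anything else is treated as normal; prob in [0, 1]
```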
```python
# 6. Extract high-attention tokens and map them back to natural-language words
def extract_attention_keywords(text, top_k=5):
    # 6.1. Clean the text (strip whitespace)
    cleaned = re.sub(r"\s+", "", text)

    # 6.2. Tokenize; only the attention is needed here, not the classifier
    encoded = tokenizer(cleaned, return_tensors="pt", truncation=True,
                        padding="max_length", max_length=128)
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # 6.3. Feed the text to the vanilla BERT model and take its attention maps
    with torch.no_grad():
        outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
        attentions = outputs.attentions  # tuple: one attention tensor per transformer block

    # 6.4. Take the last layer's attention and average over all heads and all
    #      query tokens → a 1-D vector of length seq_len
    attn = attentions[-1][0].mean(dim=0).mean(dim=0)  # shape: (seq_len,)

    # 6.5. Recover the sentence's tokens and drop special tokens
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    top_indices = attn.topk(top_k).indices.tolist()
    top_tokens = [tokens[i] for i in top_indices if tokens[i] not in ["[CLS]", "[SEP]", "[PAD]"]]

    # 6.6. Segment with jieba and map the high-attention tokens back to Chinese words
    words = list(jieba.cut(text))
    suspicious = []
    for word in words:
        # … (lines 99–102 collapsed in the diff view: presumably the check that
        #     `word` overlaps one of the high-attention tokens)
            suspicious.append(word)
            break

    # 6.7. Return up to top_k "suspicious words"; if none mapped, fall back to top_tokens
    return suspicious[:top_k] if suspicious else top_tokens[:top_k]
# ─────────────────────────────────────────────────────────────────────────────
```
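The averaging in step 6.4 is easier to follow with the shapes written out. A self-contained sketch (the dummy tensors are mine; 12 layers and 12 heads match bert-base, and the shapes follow the Hugging Face convention for `output_attentions=True`):

```python
import torch

# Stand-in for outputs.attentions: tuple of length num_layers, each element
# shaped (batch, num_heads, seq_len, seq_len).
attentions = tuple(torch.rand(1, 12, 128, 128) for _ in range(12))

last = attentions[-1][0]        # (12, 128, 128): heads x queries x keys, batch item 0
per_head = last.mean(dim=0)     # (128, 128): averaged over heads
attn = per_head.mean(dim=0)     # (128,): averaged over query positions — roughly,
                                # how much attention each token position receives
print(attn.shape)               # torch.Size([128])
```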
```python
# ─────────────────────────────────────────────────────────────────────────────
# 7. Main text-analysis function: returns the full result structure
def analyze_text(text: str):
    """
    Takes a plain-text string and returns:
    {
        "status": "詐騙" / "正常",
        "confidence": float (percentage),
        "suspicious_keywords": [list of extracted words]
    }
    """
    label, prob = predict_single_sentence(text)
    prob_percent = round(prob * 100, 2)
    status = "詐騙" if label == 1 else "正常"
    # … (lines 125–128 collapsed in the diff view: presumably setting `suspicious`
    #     via extract_attention_keywords and opening the returned dict)
        "confidence": prob_percent,
        "suspicious_keywords": suspicious or ["(模型未聚焦可疑詞)"]
    }
# ─────────────────────────────────────────────────────────────────────────────
```
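Expected usage (sample input and output values are illustrative only):

```python
# Hypothetical call.
result = analyze_text("您的帳戶已被凍結,請立即提供驗證碼解鎖")
# result has the shape documented in the docstring, e.g.:
# {"status": "詐騙", "confidence": 97.31, "suspicious_keywords": ["帳戶", "凍結", "驗證碼"]}
```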
```python
# ─────────────────────────────────────────────────────────────────────────────
# New in this commit: OCR preprocessing + image-analysis functions
# 8. Preprocessing: grayscale → CLAHE → HSV filter → binarize → upscale → blur,
#    returning a binarized NumPy image ready for OCR
def preprocess_for_easyocr(pil_image: Image.Image) -> np.ndarray:
    """
    Preprocesses a PIL Image and returns a binarized NumPy image
    (white text on a black background):
    1. PIL → NumPy (RGB → BGR)
    2. Grayscale + CLAHE (contrast enhancement)
    3. HSV colour filtering (here: masking an orange poster background)
    4. Fixed-threshold inverse binarization (dark text → white, rest → black)
    5. 2x upscaling + GaussianBlur smoothing
    """
    # 8.1. PIL → NumPy (RGB to BGR)
    img_bgr = np.array(pil_image.convert("RGB"))[:, :, ::-1]

    # 8.2. Grayscale
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # 8.3. CLAHE (contrast-limited adaptive histogram equalization)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray)

    # 8.4. HSV colour filtering (this example targets an orange background)
    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
    lower_orange = np.array([5, 100, 100])
    upper_orange = np.array([20, 255, 255])
    mask_orange = cv2.inRange(hsv, lower_orange, upper_orange)
    filtered = enhanced.copy()
    filtered[mask_orange > 0] = 255  # set the orange background to white

    # 8.5. Fixed-threshold inverse binarization (dark text → white, background → black)
    _, thresh = cv2.threshold(filtered, 200, 255, cv2.THRESH_BINARY_INV)

    # 8.6. 2x upscaling & GaussianBlur smoothing
    scaled = cv2.resize(thresh, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
    smoothed = cv2.GaussianBlur(scaled, (3, 3), 0)

    return smoothed
```
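A quick way to eyeball the preprocessing result (the input filename is hypothetical; the /tmp debug path mirrors the commented-out line in `analyze_image` below):

```python
# Hypothetical check: run the pipeline on a local image and save the result.
img = Image.open("sample_poster.png")   # hypothetical input file
binary = preprocess_for_easyocr(img)
print(binary.shape, binary.dtype)       # 2x the original height/width, dtype uint8
Image.fromarray(binary).save("/tmp/debug_processed.png")
```

Since the HSV mask and the fixed threshold of 200 are tuned to one orange-poster style, other colour schemes will likely need their own ranges.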
```python
# 9. Image analysis: OCR the preprocessed image → join the text → BERT analysis
def analyze_image(file_bytes, explain_mode="cnn"):
    """
    Takes image bytes and returns:
    {
        "status": "詐騙" / "正常" / "無法辨識文字",
        "confidence": float,
        "suspicious_keywords": [list of words]
    }
    Flow:
    1. bytes → PIL Image
    2. Preprocess → binarized NumPy image (white text on black)
    3. Read with EasyOCR and join into one long string
    4. No text found → return "unrecognizable"; otherwise → call analyze_text
    """
    # 9.1. bytes → PIL Image
    image = Image.open(io.BytesIO(file_bytes))

    # 9.2. OCR preprocessing: get the binarized NumPy image (white text on black)
    preprocessed_np = preprocess_for_easyocr(image)

    # [Optional debug] Save the preprocessed image to /tmp/debug_processed.png for inspection
    # Image.fromarray(preprocessed_np).save("/tmp/debug_processed.png")

    # 9.3. Run EasyOCR on the preprocessed image
    # (constructing the Reader on every call reloads the OCR models;
    #  hoisting it to module level would be cheaper)
    reader = easyocr.Reader(['ch_tra', 'en'], gpu=torch.cuda.is_available())
    results = reader.readtext(preprocessed_np)

    # 9.4. Join all recognized text fragments
    text = ' '.join([res[1] for res in results]).strip()

    # 9.5. If nothing was recognized, return "unrecognizable" immediately
    if not text:
        return {
            "status": "無法辨識文字",
            "confidence": 0.0,
            "suspicious_keywords": ["圖片中無可辨識的中文英文"]
        }

    # 9.6. Otherwise hand the extracted text to analyze_text for BERT analysis
    return analyze_text(text)  # NOTE: analyze_text(text: str) accepts no explain_mode;
                               # the committed call passed explain_mode=explain_mode,
                               # which would raise a TypeError at runtime
# ─────────────────────────────────────────────────────────────────────────────
```
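For context, the comment removed from the old version described `analyze_text` as the analysis function wrapped for the main API. A minimal sketch of how an API layer might consume this module (the FastAPI choice, endpoint paths, and payload shape are my assumptions, not part of this commit):

```python
# Hypothetical wiring — not part of this commit.
from fastapi import FastAPI, UploadFile

from bert_explainer import analyze_image, analyze_text

app = FastAPI()

@app.post("/analyze/text")
def analyze_text_endpoint(payload: dict):
    # Expects {"text": "..."} and returns the dict built by analyze_text.
    return analyze_text(payload["text"])

@app.post("/analyze/image")
async def analyze_image_endpoint(file: UploadFile):
    # Routes the uploaded bytes through OCR preprocessing and BERT analysis.
    return analyze_image(await file.read())
```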