jerrynnms committed
Commit f6b8f5d · verified · 1 Parent(s): fc2fb82

Update bert_explainer.py

Files changed (1)
  1. bert_explainer.py +134 -18
bert_explainer.py CHANGED
@@ -1,18 +1,25 @@
-import torch
-from transformers import BertTokenizer, BertModel
-from AI_Model_architecture import BertLSTM_CNN_Classifier
-import re
 import os
+os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+
+import io
+import re
 import requests
+import torch
 import jieba
+import numpy as np
+import cv2
+import easyocr
+from PIL import Image
+from transformers import BertTokenizer, BertModel
+from AI_Model_architecture import BertLSTM_CNN_Classifier
 
-# ✅ Device setup
+# ─────────────────────────────────────────────────────────────────────────────
+# 1. Device setup
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# Model save & download location
+# 2. Download and load the custom classification model (BertLSTM_CNN_Classifier)
 model_path = "/tmp/model.pth"
 model_url = "https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth"
-
 if not os.path.exists(model_path):
     print("📦 下載 model.pth 中...")
     response = requests.get(model_url)
@@ -23,31 +30,36 @@ if not os.path.exists(model_path):
 else:
     raise FileNotFoundError("❌ 無法下載 model.pth")
 
-# Initialize tokenizer
+# 3. Initialize the tokenizer and the custom classification model
 tokenizer = BertTokenizer.from_pretrained("ckiplab/bert-base-chinese")
 
-# ✅ Initialize the custom classification model
 model = BertLSTM_CNN_Classifier()
 model.load_state_dict(torch.load(model_path, map_location=device))
 model.to(device)
 model.eval()
 
-# Initialize the original BERT model for attention
+# 4. Initialize the original BERT model (used for attention)
 bert_model = BertModel.from_pretrained("ckiplab/bert-base-chinese", output_attentions=True)
 bert_model.to(device)
 bert_model.eval()
+# ─────────────────────────────────────────────────────────────────────────────
 
-# ✅ Predict a single sentence
 
+# ─────────────────────────────────────────────────────────────────────────────
+# 5. Single-sentence prediction function
 def predict_single_sentence(text: str, max_len=256):
+    # 5.1. Simple cleanup: strip whitespace, keep Chinese/English characters and some punctuation
     text = re.sub(r"\s+", "", text)
     text = re.sub(r"[^\u4e00-\u9fffA-Za-z0-9。,!?:/.\-]", "", text)
 
-    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=max_len)
+    # 5.2. Tokenize and convert to tensors
+    encoded = tokenizer(text, return_tensors="pt", truncation=True,
+                        padding="max_length", max_length=max_len)
     input_ids = encoded["input_ids"].to(device)
     attention_mask = encoded["attention_mask"].to(device)
     token_type_ids = encoded["token_type_ids"].to(device)
 
+    # 5.3. Model inference
     with torch.no_grad():
         output = model(input_ids, attention_mask, token_type_ids)
         prob = output.item()
@@ -55,24 +67,32 @@ def predict_single_sentence(text: str, max_len=256):
 
     return label, prob
 
-# ✅ Extract high-attention tokens and map them to natural-language words
 
+# 6. Extract high-attention tokens and map them to natural-language words
 def extract_attention_keywords(text, top_k=5):
+    # 6.1. Clean the text (remove whitespace)
    cleaned = re.sub(r"\s+", "", text)
-    encoded = tokenizer(cleaned, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
+
+    # 6.2. Tokenize; only attention is needed here, not the classifier
+    encoded = tokenizer(cleaned, return_tensors="pt", truncation=True,
+                        padding="max_length", max_length=128)
     input_ids = encoded["input_ids"].to(device)
     attention_mask = encoded["attention_mask"].to(device)
 
+    # 6.3. Run the text through the original BERT to get the attention weights
     with torch.no_grad():
         outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
-        attentions = outputs.attentions
+        attentions = outputs.attentions  # tuple: one attention tensor per transformer block
+
+    # 6.4. Take the last layer's attention and average over all heads and all tokens → 1-D vector (seq_len)
+    attn = attentions[-1][0].mean(dim=0).mean(dim=0)  # shape: (seq_len,)
 
-    attn = attentions[-1][0].mean(dim=0).mean(dim=0)
+    # 6.5. Get every token of the sentence, excluding special tokens
     tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
     top_indices = attn.topk(top_k).indices.tolist()
-
     top_tokens = [tokens[i] for i in top_indices if tokens[i] not in ["[CLS]", "[SEP]", "[PAD]"]]
 
+    # 6.6. Segment with jieba and map the high-attention tokens back to Chinese words
     words = list(jieba.cut(text))
     suspicious = []
     for word in words:
@@ -83,11 +103,22 @@ def extract_attention_keywords(text, top_k=5):
             suspicious.append(word)
             break
 
+    # 6.7. Return the top_k "suspicious words"; if nothing was mapped, fall back to top_tokens
     return suspicious[:top_k] if suspicious else top_tokens[:top_k]
+# ─────────────────────────────────────────────────────────────────────────────
 
-# ✅ Analysis function wrapped for the main API
 
+# ─────────────────────────────────────────────────────────────────────────────
+# 7. Main text-analysis function: returns the full result structure
 def analyze_text(text: str):
+    """
+    Takes a piece of plain text and returns:
+    {
+        "status": "詐騙" / "正常",
+        "confidence": float (percentage),
+        "suspicious_keywords": [list of extracted words]
+    }
+    """
     label, prob = predict_single_sentence(text)
     prob_percent = round(prob * 100, 2)
     status = "詐騙" if label == 1 else "正常"
@@ -98,3 +129,88 @@ def analyze_text(text: str):
         "confidence": prob_percent,
         "suspicious_keywords": suspicious or ["(模型未聚焦可疑詞)"]
     }
+# ─────────────────────────────────────────────────────────────────────────────
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# Added below: OCR preprocessing and image-analysis functions
+# 8. Preprocessing: grayscale → CLAHE → HSV filtering → binarization → upscaling → blur; returns a binarized NumPy image ready for OCR
+def preprocess_for_easyocr(pil_image: Image.Image) -> np.ndarray:
+    """
+    Preprocess a PIL Image and return a binarized NumPy image (white text on black):
+    1. PIL → NumPy (RGB → BGR)
+    2. Grayscale + CLAHE (contrast enhancement)
+    3. HSV color filtering (example: filter out an orange poster background)
+    4. Fixed-threshold inverse binarization (dark text → white, everything else → black)
+    5. 2x upscaling + GaussianBlur smoothing
+    """
+    # 8.1. PIL → NumPy (RGB to BGR)
+    img_bgr = np.array(pil_image.convert("RGB"))[:, :, ::-1]
+
+    # 8.2. Convert to grayscale
+    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
+
+    # 8.3. CLAHE (contrast-limited adaptive histogram equalization)
+    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+    enhanced = clahe.apply(gray)
+
+    # 8.4. HSV color filtering (this example targets an orange background)
+    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
+    lower_orange = np.array([5, 100, 100])
+    upper_orange = np.array([20, 255, 255])
+    mask_orange = cv2.inRange(hsv, lower_orange, upper_orange)
+    filtered = enhanced.copy()
+    filtered[mask_orange > 0] = 255  # set the orange background to white
+
+    # 8.5. Fixed-threshold inverse binarization (dark text → white, background → black)
+    _, thresh = cv2.threshold(filtered, 200, 255, cv2.THRESH_BINARY_INV)
+
+    # 8.6. 2x upscaling and GaussianBlur smoothing
+    scaled = cv2.resize(thresh, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
+    smoothed = cv2.GaussianBlur(scaled, (3, 3), 0)
+
+    return smoothed
+
+
+# 9. Image analysis: OCR the preprocessed image → join the text → BERT analysis
+def analyze_image(file_bytes, explain_mode="cnn"):
+    """
+    Takes image bytes and returns:
+    {
+        "status": "詐騙" / "正常" / "無法辨識文字",
+        "confidence": float,
+        "suspicious_keywords": [list of words]
+    }
+    Flow:
+    1. bytes → PIL Image
+    2. Preprocess → binarized NumPy image (white text on black)
+    3. Read with EasyOCR and join the results into one long string
+    4. No text found → return "unrecognizable"; otherwise call analyze_text
+    """
+    # 9.1. bytes → PIL Image
+    image = Image.open(io.BytesIO(file_bytes))
+
+    # 9.2. OCR preprocessing: get the binarized NumPy image (white text on black)
+    preprocessed_np = preprocess_for_easyocr(image)
+
+    # [Optional debug] Save the preprocessed image to /tmp/debug_processed.png for inspection
+    # Image.fromarray(preprocessed_np).save("/tmp/debug_processed.png")
+
+    # 9.3. Run EasyOCR on the preprocessed image
+    reader = easyocr.Reader(['ch_tra', 'en'], gpu=torch.cuda.is_available())
+    results = reader.readtext(preprocessed_np)
+
+    # 9.4. Join all recognized text
+    text = ' '.join([res[1] for res in results]).strip()
+
+    # 9.5. If no text was recognized, return "unrecognizable" right away
+    if not text:
+        return {
+            "status": "無法辨識文字",
+            "confidence": 0.0,
+            "suspicious_keywords": ["圖片中無可辨識的中文英文"]
+        }
+
+    # 9.6. If text was extracted, hand it to analyze_text for BERT analysis (explain_mode is not forwarded; analyze_text only takes the text)
+    return analyze_text(text)
+# ─────────────────────────────────────────────────────────────────────────────
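For quick reference, a minimal usage sketch of the two entry points this commit exposes (analyze_text and analyze_image). It is not part of the commit: the module name bert_explainer, the sample sentence, and the file poster.jpg are assumptions for illustration only, and it presumes model.pth downloads successfully at import time.

# Minimal usage sketch (assumptions: bert_explainer.py is importable as a module,
# model.pth is reachable, and "poster.jpg" is a hypothetical local screenshot).
from bert_explainer import analyze_text, analyze_image

# Text path → {"status": ..., "confidence": ..., "suspicious_keywords": [...]}
result = analyze_text("您的帳戶已被凍結,請點擊連結立即解除")  # made-up sample sentence
print(result["status"], result["confidence"], result["suspicious_keywords"])

# Image path: raw bytes → preprocess_for_easyocr → EasyOCR → analyze_text
with open("poster.jpg", "rb") as f:
    print(analyze_image(f.read()))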