jerrynnms committed on
Commit
561a7b3
·
verified ·
1 Parent(s): db3a973

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -31
app.py CHANGED
@@ -1,35 +1,42 @@
1
  import os
2
-
3
- # ✅ 修正權限錯誤:改用 /tmp 資料夾儲存 cache
4
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache"
5
- os.environ["HF_HOME"] = "/tmp/.cache"
6
- os.environ["TORCH_HOME"] = "/tmp/.cache"
7
- os.environ["HF_DATASETS_CACHE"] = "/tmp/.cache"
 
 
 
 
 
8
 
9
  from fastapi import FastAPI, HTTPException, File, UploadFile
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from fastapi.staticfiles import StaticFiles
12
  from fastapi.responses import FileResponse, JSONResponse
13
  from pydantic import BaseModel
14
- from datetime import datetime
15
- from typing import Optional, List
16
  from firebase_admin import credentials, firestore
17
  import firebase_admin
18
- import pytz
19
- import json
20
- import requests
21
- import torch
22
- import pytesseract
23
- import cv2
24
- import numpy as np
25
- from PIL import Image
26
- import io
27
  from AI_Model_architecture import BertLSTM_CNN_Classifier
28
  from bert_explainer import analyze_text as bert_analyze_text
29
 
 
 
 
 
 
 
 
 
 
 
 
30
  app = FastAPI(
31
  title="詐騙訊息辨識 API",
32
- description="使用 BERT 模型分析輸入文字是否為詐騙內容",
33
  version="1.0.0"
34
  )
35
 
@@ -41,22 +48,32 @@ app.add_middleware(
41
  allow_headers=["*"],
42
  )
43
 
 
44
  app.mount("/static", StaticFiles(directory="."), name="static")
45
 
46
  @app.get("/", response_class=FileResponse)
47
  async def serve_index():
 
 
 
48
  return FileResponse("index.html")
49
 
 
 
50
  try:
51
  cred_data = os.getenv("FIREBASE_CREDENTIALS")
52
  if not cred_data:
53
  raise ValueError("FIREBASE_CREDENTIALS 環境變數未設置")
54
- cred = credentials.Certificate({"type": "service_account", **json.loads(cred_data)})
55
- firebase_admin.initialize_app(cred)
56
  db = firestore.client()
57
  except Exception as e:
 
58
  print(f"Firebase 初始化錯誤: {e}")
 
59
 
 
 
60
  model_path = "/tmp/model.pth"
61
  model_url = "https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth"
62
  if not os.path.exists(model_path):
@@ -65,12 +82,15 @@ if not os.path.exists(model_path):
65
  with open(model_path, "wb") as f:
66
  f.write(response.content)
67
  else:
68
- raise FileNotFoundError("❌ 無法從 Hugging Face 載入 model.pth")
69
 
70
  model = BertLSTM_CNN_Classifier()
71
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
72
  model.eval()
 
73
 
 
 
74
  class TextAnalysisRequest(BaseModel):
75
  text: str
76
  user_id: Optional[str] = None
@@ -81,9 +101,13 @@ class TextAnalysisResponse(BaseModel):
81
  suspicious_keywords: List[str]
82
  analysis_timestamp: datetime
83
  text_id: str
 
84
 
85
  @app.post("/predict", response_model=TextAnalysisResponse)
86
  async def analyze_text_api(request: TextAnalysisRequest):
 
 
 
87
  try:
88
  tz = pytz.timezone("Asia/Taipei")
89
  now = datetime.now(tz)
@@ -91,8 +115,10 @@ async def analyze_text_api(request: TextAnalysisRequest):
91
  date_str = now.strftime("%Y-%m-%d %H:%M:%S")
92
  collection = now.strftime("%Y%m%d")
93
 
 
94
  result = bert_analyze_text(request.text)
95
 
 
96
  record = {
97
  "text_id": doc_id,
98
  "text": request.text,
@@ -101,7 +127,6 @@ async def analyze_text_api(request: TextAnalysisRequest):
101
  "timestamp": date_str,
102
  "type": "text_analysis"
103
  }
104
-
105
  db.collection(collection).document(doc_id).set(record)
106
 
107
  return TextAnalysisResponse(
@@ -116,6 +141,9 @@ async def analyze_text_api(request: TextAnalysisRequest):
116
 
117
  @app.post("/feedback")
118
  async def save_user_feedback(feedback: dict):
 
 
 
119
  try:
120
  tz = pytz.timezone("Asia/Taipei")
121
  timestamp_str = datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
@@ -126,23 +154,79 @@ async def save_user_feedback(feedback: dict):
126
  except Exception as e:
127
  raise HTTPException(status_code=500, detail=str(e))
128
 
129
- # ✅ 加入 OpenCV 圖像前處理函數
130
- def preprocess_image_for_ocr(pil_image):
131
- img = np.array(pil_image.convert('RGB'))
132
- gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
133
- _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
134
- scaled = cv2.resize(thresh, None, fx=1.5, fy=1.5, interpolation=cv2.INTER_LINEAR)
135
- return Image.fromarray(scaled)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  @app.post("/analyze-image")
138
  async def analyze_uploaded_image(file: UploadFile = File(...)):
 
 
 
139
  try:
 
140
  image_bytes = await file.read()
141
  image = Image.open(io.BytesIO(image_bytes))
142
 
 
143
  processed_image = preprocess_image_for_ocr(image)
144
- extracted_text = pytesseract.image_to_string(processed_image, lang="chi_tra+eng").strip()
145
 
 
 
 
 
 
 
 
 
146
  if not extracted_text:
147
  return {
148
  "extracted_text": "",
@@ -153,6 +237,7 @@ async def analyze_uploaded_image(file: UploadFile = File(...)):
153
  }
154
  }
155
 
 
156
  result = bert_analyze_text(extracted_text)
157
 
158
  return {
@@ -161,4 +246,12 @@ async def analyze_uploaded_image(file: UploadFile = File(...)):
161
  }
162
 
163
  except Exception as e:
164
- raise HTTPException(status_code=500, detail=f"圖片辨識錯誤:{str(e)}")
 
 
 
 
 
 
 
 
 
1
import os

# ─────────────────────────────────────────────────────────────────────────────
# 0. Work around cache permission errors by pointing every cache at /tmp/.cache.
#    These variables MUST be set before the ML stack is imported: the Hugging
#    Face libraries resolve HF_HOME / TRANSFORMERS_CACHE when they are first
#    imported, so assigning them afterwards has no effect on the cache path.
#    NOTE(review): AI_Model_architecture / bert_explainer presumably import
#    transformers — that is why the assignments sit above those imports.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache"
os.environ["HF_HOME"] = "/tmp/.cache"
os.environ["TORCH_HOME"] = "/tmp/.cache"
os.environ["HF_DATASETS_CACHE"] = "/tmp/.cache"

import io
import json
import shutil
from datetime import datetime
from typing import Optional, List

import requests
import torch
import pytz
import pytesseract
import cv2
import numpy as np
from PIL import Image

from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel

from firebase_admin import credentials, firestore
import firebase_admin

from AI_Model_architecture import BertLSTM_CNN_Classifier
from bert_explainer import analyze_text as bert_analyze_text

# 1. Tell pytesseract where the Tesseract OCR binary lives. Prefer whatever is
#    on PATH so the app also runs on a developer machine; fall back to the
#    location used on the Hugging Face Space image.
pytesseract.pytesseract.tesseract_cmd = shutil.which("tesseract") or "/usr/bin/tesseract"
# ─────────────────────────────────────────────────────────────────────────────
36
+
37
  app = FastAPI(
38
  title="詐騙訊息辨識 API",
39
+ description="使用 BERT 模型與 OCR 圖像前處理,辨識文字並做詐騙判斷",
40
  version="1.0.0"
41
  )
42
 
 
48
  allow_headers=["*"],
49
  )
50
 
51
+ # 掛載根目錄為靜態檔,用於提供 index.html
52
  app.mount("/static", StaticFiles(directory="."), name="static")
53
 
54
@app.get("/", response_class=FileResponse)
async def serve_index():
    """Serve the front-end page ``index.html`` for requests to the site root."""
    index_page = FileResponse("index.html")
    return index_page
60
 
61
+ # ─────────────────────────────────────────────────────────────────────────────
62
+ # 2. Firebase 初始化(以環境變數 FIREBASE_CREDENTIALS 儲存 service account JSON 字串)
63
# Firestore client handle. Pre-bound to None so the name always exists at
# module level, even when initialisation below fails; code that uses it can
# test ``db is None`` instead of tripping over an undefined name.
db = None
try:
    # The service-account JSON is supplied as a string in FIREBASE_CREDENTIALS.
    cred_data = os.getenv("FIREBASE_CREDENTIALS")
    if not cred_data:
        raise ValueError("FIREBASE_CREDENTIALS 環境變數未設置")
    # "type" is given a default here; a "type" key inside the JSON overrides it.
    firebase_cred = credentials.Certificate({"type": "service_account", **json.loads(cred_data)})
    firebase_admin.initialize_app(firebase_cred)
    db = firestore.client()
except Exception as e:
    # Deliberately best-effort: report the failure on the console but keep the
    # rest of the application importable and running.
    print(f"Firebase 初始化錯誤: {e}")
73
+ # ─────────────────────────────────────────────────────────────────────────────
74
 
75
+ # ─────────────────────────────────────────────────────────────────────────────
76
+ # 3. 下載並載入 PyTorch BERT+LSTM+CNN 模型
77
  model_path = "/tmp/model.pth"
78
  model_url = "https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth"
79
  if not os.path.exists(model_path):
 
82
  with open(model_path, "wb") as f:
83
  f.write(response.content)
84
  else:
85
+ raise FileNotFoundError("❌ 無法從 Hugging Face 下載 model.pth")
86
 
87
  model = BertLSTM_CNN_Classifier()
88
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
89
  model.eval()
90
+ # ─────────────────────────────────────────────────────────────────────────────
91
 
92
+ # ─────────────────────────────────────────────────────────────────────────────
93
+ # 4. 定義 Pydantic Request / Response Model
94
  class TextAnalysisRequest(BaseModel):
95
  text: str
96
  user_id: Optional[str] = None
 
101
  suspicious_keywords: List[str]
102
  analysis_timestamp: datetime
103
  text_id: str
104
+ # ─────────────────────────────────────────────────────────────────────────────
105
 
106
  @app.post("/predict", response_model=TextAnalysisResponse)
107
  async def analyze_text_api(request: TextAnalysisRequest):
108
+ """
109
+ 文字輸入分析:回傳是否為詐騙訊息、信心度、可疑關鍵詞清單
110
+ """
111
  try:
112
  tz = pytz.timezone("Asia/Taipei")
113
  now = datetime.now(tz)
 
115
  date_str = now.strftime("%Y-%m-%d %H:%M:%S")
116
  collection = now.strftime("%Y%m%d")
117
 
118
+ # 使用 Bert+LSTM+CNN 模型做文字判斷
119
  result = bert_analyze_text(request.text)
120
 
121
+ # 把結果存到 Firestore
122
  record = {
123
  "text_id": doc_id,
124
  "text": request.text,
 
127
  "timestamp": date_str,
128
  "type": "text_analysis"
129
  }
 
130
  db.collection(collection).document(doc_id).set(record)
131
 
132
  return TextAnalysisResponse(
 
141
 
142
  @app.post("/feedback")
143
  async def save_user_feedback(feedback: dict):
144
+ """
145
+ 使用者回饋:把自訂的 feedback JSON 存到 Firestore 的 user_feedback collection
146
+ """
147
  try:
148
  tz = pytz.timezone("Asia/Taipei")
149
  timestamp_str = datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
 
154
  except Exception as e:
155
  raise HTTPException(status_code=500, detail=str(e))
156
 
157
+ # ─────────────────────────────────────────────────────────────────────────────
158
+ # 5. OCR 前處理:灰階 → 中值去噪 → 自適應二值化 → 形態學閉運算 → 校正傾斜 → 放大 & 平滑
159
# Skew angles smaller than this (in degrees) are not worth a resampling pass.
_MIN_DESKEW_DEGREES = 0.1


def _estimate_skew_angle(binary: np.ndarray) -> float:
    """Estimate how far the text in a binarised image is rotated.

    Args:
        binary: Single-channel uint8 image containing only 0 and 255.

    Returns:
        Angle in degrees, in the range [-45, 45), to pass to
        ``cv2.getRotationMatrix2D`` so that the text becomes level.
        ``0.0`` when the image contains no text pixels.
    """
    # Text is taken to be the minority class, so the estimate works for dark
    # text on a light page as well as light text on a dark page.
    white_pixels = int(np.count_nonzero(binary))
    if white_pixels * 2 >= binary.size:
        text_mask = binary == 0
    else:
        text_mask = binary > 0

    rows, cols = np.nonzero(text_mask)
    if rows.size == 0:
        return 0.0

    # minAreaRect wants (x, y) points stored as 32-bit values; np.nonzero
    # yields (row, col) in the platform's default integer width.
    points = np.column_stack((cols, rows)).astype(np.float32)
    angle = cv2.minAreaRect(points)[-1]

    # The reported angle range differs between OpenCV releases ([-90, 0) vs
    # (0, 90]). A rectangle is unchanged by a 90-degree turn, so folding the
    # angle into [-45, 45) gives the same answer under either convention.
    return (angle + 45.0) % 90.0 - 45.0


def preprocess_image_for_ocr(pil_image: Image.Image) -> Image.Image:
    """Clean up an image so that Tesseract has an easier time reading it.

    Pipeline:
        1. PIL image → NumPy array (RGB)
        2. Convert to grayscale
        3. Median blur to remove speckle noise
        4. Adaptive (Gaussian) thresholding to a binary image
        5. Morphological closing with a 2x2 kernel
        6. Deskew: rotate so the text block is level
        7. Upscale 2x and smooth with a Gaussian blur
        8. NumPy array → PIL image

    Args:
        pil_image: The image to prepare; any PIL mode that converts to RGB.

    Returns:
        A single-channel PIL image, twice the input's width and height.
    """
    # 1-2. PIL → NumPy, then grayscale. Converting straight from RGB avoids
    # building a negatively-strided BGR view just to undo it again.
    rgb = np.array(pil_image.convert("RGB"))
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    # 3. Median blur
    denoised = cv2.medianBlur(gray, 3)
    # 4. Adaptive thresholding (block size 11, constant 2)
    thresh = cv2.adaptiveThreshold(
        denoised, 255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        11, 2
    )
    # 5. Morphological closing (kernel = 2x2)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    morph = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel, iterations=1)
    # 6. Deskew
    angle = _estimate_skew_angle(morph)
    if abs(angle) >= _MIN_DESKEW_DEGREES:
        (h, w) = morph.shape
        M = cv2.getRotationMatrix2D((w / 2.0, h / 2.0), angle, 1.0)
        morph = cv2.warpAffine(
            morph, M, (w, h),
            flags=cv2.INTER_CUBIC,
            borderMode=cv2.BORDER_REPLICATE
        )
    # 7. Upscale 2x and smooth
    scaled = cv2.resize(morph, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
    smoothed = cv2.GaussianBlur(scaled, (3, 3), 0)
    # 8. NumPy → PIL
    return Image.fromarray(smoothed)
207
+ # ─────────────────────────────────────────────────────────────────────────────
208
 
209
  @app.post("/analyze-image")
210
  async def analyze_uploaded_image(file: UploadFile = File(...)):
211
+ """
212
+ 圖片上傳並進行 OCR 辨識,擷取文字後再用 BERT 模型做詐騙分析
213
+ """
214
  try:
215
+ # 1. 讀取上傳的檔案 bytes
216
  image_bytes = await file.read()
217
  image = Image.open(io.BytesIO(image_bytes))
218
 
219
+ # 2. 對 PIL Image 做完整前處理
220
  processed_image = preprocess_image_for_ocr(image)
 
221
 
222
+ # 3. 帶參數呼叫 pytesseract OCR
223
+ custom_config = r"-l chi_tra+eng --oem 3 --psm 6"
224
+ extracted_text = pytesseract.image_to_string(
225
+ processed_image,
226
+ config=custom_config
227
+ ).strip()
228
+
229
+ # 如果 OCR 完全抓不到任何文字,就回傳「無法辨識」
230
  if not extracted_text:
231
  return {
232
  "extracted_text": "",
 
237
  }
238
  }
239
 
240
+ # 4. 如果擷取到文字,就套用 BERT 模型做詐騙分析
241
  result = bert_analyze_text(extracted_text)
242
 
243
  return {
 
246
  }
247
 
248
  except Exception as e:
249
+ # 任何錯誤都以 500 回傳
250
+ raise HTTPException(status_code=500, detail=f"圖片辨識錯誤:{str(e)}")
251
+
252
+ # ─────────────────────────────────────────────────────────────────────────────
253
+ # 6. 啟動程式入口:讓本機或 Hugging Face Space 都能用 uvicorn 直接執行
254
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, using
    # the PORT environment variable when present and 7860 otherwise.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))