jerrynnms commited on
Commit
768506f
·
verified ·
1 Parent(s): ab9f78e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -44
app.py CHANGED
@@ -1,12 +1,14 @@
1
  import os
2
 
3
- # ✅ 修正為 /tmp 可寫入的路徑,解決 PermissionError
4
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers"
5
  os.environ["HF_HOME"] = "/tmp/huggingface"
6
  os.environ["TORCH_HOME"] = "/tmp/torch"
7
 
8
  from fastapi import FastAPI, HTTPException
9
  from fastapi.middleware.cors import CORSMiddleware
 
 
10
  from pydantic import BaseModel
11
  from datetime import datetime
12
  from typing import Optional, List
@@ -18,12 +20,14 @@ import json
18
  import requests
19
  import torch
20
 
 
21
  app = FastAPI(
22
  title="詐騙訊息辨識 API",
23
  description="使用 BERT 模型分析輸入文字是否為詐騙內容",
24
  version="1.0.0"
25
  )
26
 
 
27
  app.add_middleware(
28
  CORSMiddleware,
29
  allow_origins=["*"],
@@ -32,18 +36,15 @@ app.add_middleware(
32
  allow_headers=["*"],
33
  )
34
 
35
- class TextAnalysisRequest(BaseModel):
36
- text: str
37
- user_id: Optional[str] = None
38
 
39
- class TextAnalysisResponse(BaseModel):
40
- status: str
41
- confidence: float
42
- suspicious_keywords: List[str]
43
- analysis_timestamp: datetime
44
- text_id: str
45
 
46
- # 初始化 Firebase 使用環境變數
47
  try:
48
  cred_data = os.getenv("FIREBASE_CREDENTIALS")
49
  if not cred_data:
@@ -54,76 +55,79 @@ try:
54
  except Exception as e:
55
  print(f"Firebase 初始化錯誤: {e}")
56
 
57
- # 從 Google Drive 載入 model.pth
58
- def load_model_from_drive():
59
  model_url = "https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth"
 
60
  response = requests.get(model_url)
61
  if response.status_code == 200:
62
- with open("model.pth", "wb") as f:
63
  f.write(response.content)
64
- return True
65
- return False
66
 
67
- if not os.path.exists("model.pth"):
68
- if not load_model_from_drive():
69
- raise FileNotFoundError("無法從 Google Drive 載入 model.pth")
70
 
71
  from AI_Model_architecture import BertLSTM_CNN_Classifier
72
  model = BertLSTM_CNN_Classifier()
73
- model.load_state_dict(torch.load("model.pth", map_location="cpu"))
74
  model.eval()
75
 
76
- @app.get("/")
77
- async def root():
78
- return {"message": "詐騙文字辨識 API 已啟動", "version": "1.0.0", "status": "active", "docs": "/docs"}
 
 
 
 
 
 
 
 
79
 
 
80
  @app.post("/predict", response_model=TextAnalysisResponse)
81
  async def analyze_text_api(request: TextAnalysisRequest):
82
  try:
83
  tz = pytz.timezone("Asia/Taipei")
84
- taiwan_now = datetime.now(tz)
85
- collection_name = taiwan_now.strftime("%Y%m%d")
86
- document_id = taiwan_now.strftime("%Y%m%dT%H%M%S")
87
- timestamp_str = taiwan_now.strftime("%Y-%m-%d %H:%M:%S")
88
 
89
  result = bert_analyze_text(request.text)
90
 
91
  record = {
92
- "text_id": document_id,
93
  "text": request.text,
94
  "user_id": request.user_id,
95
- "analysis_result": {
96
- "status": result["status"],
97
- "confidence": result["confidence"],
98
- "suspicious_keywords": result["suspicious_keywords"],
99
- },
100
- "timestamp": timestamp_str,
101
  "type": "text_analysis"
102
  }
103
 
104
- db.collection(collection_name).document(document_id).set(record)
105
 
106
  return TextAnalysisResponse(
107
  status=result["status"],
108
  confidence=result["confidence"],
109
  suspicious_keywords=result["suspicious_keywords"],
110
- analysis_timestamp=taiwan_now,
111
- text_id=document_id
112
  )
113
  except Exception as e:
114
  raise HTTPException(status_code=500, detail=str(e))
115
-
116
- from fastapi.responses import JSONResponse
117
  @app.post("/feedback")
118
  async def save_user_feedback(feedback: dict):
119
  try:
120
  tz = pytz.timezone("Asia/Taipei")
121
- taiwan_now = datetime.now(tz)
122
- timestamp_str = taiwan_now.strftime("%Y-%m-%d %H:%M:%S")
123
-
124
- feedback["used_in_training"] = False
125
  feedback["timestamp"] = timestamp_str
126
-
127
  db.collection("user_feedback").add(feedback)
128
  return {"message": "✅ 已記錄使用者回饋"}
129
  except Exception as e:
 
1
  import os
2
 
3
+ # ✅ Hugging Face 建議路徑(防止 cache 錯誤)
4
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers"
5
  os.environ["HF_HOME"] = "/tmp/huggingface"
6
  os.environ["TORCH_HOME"] = "/tmp/torch"
7
 
8
  from fastapi import FastAPI, HTTPException
9
  from fastapi.middleware.cors import CORSMiddleware
10
+ from fastapi.staticfiles import StaticFiles
11
+ from fastapi.responses import FileResponse, JSONResponse
12
  from pydantic import BaseModel
13
  from datetime import datetime
14
  from typing import Optional, List
 
20
  import requests
21
  import torch
22
 
23
+ # ✅ 初始化 FastAPI
24
  app = FastAPI(
25
  title="詐騙訊息辨識 API",
26
  description="使用 BERT 模型分析輸入文字是否為詐騙內容",
27
  version="1.0.0"
28
  )
29
 
30
+ # ✅ 跨域處理
31
  app.add_middleware(
32
  CORSMiddleware,
33
  allow_origins=["*"],
 
36
  allow_headers=["*"],
37
  )
38
 
39
+ # ✅ 掛載靜態檔案:支援 script.js / style.css
40
+ app.mount("/static", StaticFiles(directory="."), name="static")
 
41
 
42
+ # ✅ 回傳首頁 index.html
43
+ @app.get("/", response_class=FileResponse)
44
+ async def serve_index():
45
+ return FileResponse("index.html")
 
 
46
 
47
+ # Firebase 初始化
48
  try:
49
  cred_data = os.getenv("FIREBASE_CREDENTIALS")
50
  if not cred_data:
 
55
  except Exception as e:
56
  print(f"Firebase 初始化錯誤: {e}")
57
 
58
+ # Hugging Face Hub 載入模型(改為 /tmp)
59
+ def load_model_from_hub():
60
  model_url = "https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth"
61
+ model_path = "/tmp/model.pth"
62
  response = requests.get(model_url)
63
  if response.status_code == 200:
64
+ with open(model_path, "wb") as f:
65
  f.write(response.content)
66
+ return model_path
67
+ raise FileNotFoundError("❌ 無法從 Hugging Face 載入 model.pth")
68
 
69
+ model_path = "/tmp/model.pth"
70
+ if not os.path.exists(model_path):
71
+ model_path = load_model_from_hub()
72
 
73
  from AI_Model_architecture import BertLSTM_CNN_Classifier
74
  model = BertLSTM_CNN_Classifier()
75
+ model.load_state_dict(torch.load(model_path, map_location="cpu"))
76
  model.eval()
77
 
78
+ # ✅ 資料格式
79
+ class TextAnalysisRequest(BaseModel):
80
+ text: str
81
+ user_id: Optional[str] = None
82
+
83
+ class TextAnalysisResponse(BaseModel):
84
+ status: str
85
+ confidence: float
86
+ suspicious_keywords: List[str]
87
+ analysis_timestamp: datetime
88
+ text_id: str
89
 
90
+ # ✅ /predict API
91
  @app.post("/predict", response_model=TextAnalysisResponse)
92
  async def analyze_text_api(request: TextAnalysisRequest):
93
  try:
94
  tz = pytz.timezone("Asia/Taipei")
95
+ now = datetime.now(tz)
96
+ doc_id = now.strftime("%Y%m%dT%H%M%S")
97
+ date_str = now.strftime("%Y-%m-%d %H:%M:%S")
98
+ collection = now.strftime("%Y%m%d")
99
 
100
  result = bert_analyze_text(request.text)
101
 
102
  record = {
103
+ "text_id": doc_id,
104
  "text": request.text,
105
  "user_id": request.user_id,
106
+ "analysis_result": result,
107
+ "timestamp": date_str,
 
 
 
 
108
  "type": "text_analysis"
109
  }
110
 
111
+ db.collection(collection).document(doc_id).set(record)
112
 
113
  return TextAnalysisResponse(
114
  status=result["status"],
115
  confidence=result["confidence"],
116
  suspicious_keywords=result["suspicious_keywords"],
117
+ analysis_timestamp=now,
118
+ text_id=doc_id
119
  )
120
  except Exception as e:
121
  raise HTTPException(status_code=500, detail=str(e))
122
+
123
+ # /feedback API
124
  @app.post("/feedback")
125
  async def save_user_feedback(feedback: dict):
126
  try:
127
  tz = pytz.timezone("Asia/Taipei")
128
+ timestamp_str = datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
 
 
 
129
  feedback["timestamp"] = timestamp_str
130
+ feedback["used_in_training"] = False
131
  db.collection("user_feedback").add(feedback)
132
  return {"message": "✅ 已記錄使用者回饋"}
133
  except Exception as e: