DuyKien016 commited on
Commit
32659a5
·
verified ·
1 Parent(s): f85223a

requirements.txt

Browse files

torch
torchvision
transformers
Pillow
gradio

Files changed (1) hide show
  1. app.py +123 -0
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import torch
3
+ import gradio as gr
4
+ from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification
5
+ from PIL import Image, ImageEnhance, ImageOps
6
+ import torchvision.transforms as T
7
+
8
+ # ===================== CONFIG =====================
9
+ VINTERN_PATH = "5CD-AI/Vintern-1B-v3_5"
10
+ PHOBERT_PATH = "DuyKien016/phobert-scam-detector"
11
+
12
+ # ===================== LOAD MODELS =====================
13
+ print("🔄 Loading Vintern model...")
14
+ vintern_model = AutoModel.from_pretrained(
15
+ VINTERN_PATH,
16
+ trust_remote_code=True,
17
+ torch_dtype="auto",
18
+ device_map="auto",
19
+ low_cpu_mem_usage=True
20
+ ).eval()
21
+ vintern_tokenizer = AutoTokenizer.from_pretrained(
22
+ VINTERN_PATH,
23
+ trust_remote_code=True
24
+ )
25
+ print("✅ Vintern loaded.")
26
+
27
+ print("🔄 Loading PhoBERT model...")
28
+ phobert_tokenizer = AutoTokenizer.from_pretrained(PHOBERT_PATH, use_fast=False)
29
+ phobert_model = AutoModelForSequenceClassification.from_pretrained(PHOBERT_PATH).eval().to(
30
+ "cuda" if torch.cuda.is_available() else "cpu"
31
+ )
32
+ print("✅ PhoBERT loaded.")
33
+
34
+ # ===================== FUNCTIONS =====================
35
+ def ocr_vintern(image):
36
+ img = image.convert("RGB")
37
+ max_size = (448, 448)
38
+ img.thumbnail(max_size, Image.Resampling.LANCZOS)
39
+ img = ImageOps.pad(img, max_size, color=(255, 255, 255))
40
+ img = ImageEnhance.Contrast(img).enhance(1.5)
41
+
42
+ transform = T.Compose([
43
+ T.ToTensor(),
44
+ T.Normalize(mean=[0.485, 0.456, 0.406],
45
+ std=[0.229, 0.224, 0.225])
46
+ ])
47
+ pixel_values = transform(img).unsqueeze(0).to(vintern_model.device)
48
+
49
+ prompt = """
50
+ <image>
51
+ Hãy đọc nội dung trong ảnh chụp màn hình tin nhắn và xuất ra kết quả **chỉ** gồm các tin nhắn.
52
+
53
+ 📌 Quy tắc:
54
+ 1. Mỗi ô chat = 1 tin nhắn.
55
+ 2. Không giữ lại thời gian, tên người, emoji, icon hệ thống.
56
+ 3. Chỉ có văn bản thuần.
57
+ 4. Không thêm bình luận hoặc giải thích.
58
+
59
+ 📋 Định dạng:
60
+ Tin nhắn 1: ...
61
+ Tin nhắn 2: ...
62
+ Tin nhắn 3: ...
63
+ """
64
+
65
+ response, _ = vintern_model.chat(
66
+ tokenizer=vintern_tokenizer,
67
+ pixel_values=pixel_values,
68
+ question=prompt,
69
+ generation_config=dict(max_new_tokens=1024, do_sample=False, num_beams=3),
70
+ history=None,
71
+ return_history=True
72
+ )
73
+
74
+ messages = re.findall(r"Tin nhắn \d+: (.+?)(?=\nTin nhắn|\Z)", response, re.S)
75
+ cleaned_messages = [re.sub(r"\s+", " ", msg.strip()) for msg in messages if msg.strip()]
76
+ return cleaned_messages
77
+
78
+ def predict_phobert(texts):
79
+ results = []
80
+ for text in texts:
81
+ encoded = phobert_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=256)
82
+ encoded = {k: v.to(phobert_model.device) for k, v in encoded.items()}
83
+
84
+ with torch.no_grad():
85
+ logits = phobert_model(**encoded).logits
86
+ probs = torch.softmax(logits, dim=1).squeeze()
87
+ label = torch.argmax(probs).item()
88
+
89
+ results.append({
90
+ "text": text,
91
+ "prediction": "LỪA ĐẢO" if label == 1 else "BÌNH THƯỜNG",
92
+ "confidence": f"{probs[label]*100:.2f}%"
93
+ })
94
+ return results
95
+
96
+ # ===================== GRADIO INTERFACE =====================
97
+ def detect(image, text):
98
+ if image is not None:
99
+ extracted_texts = ocr_vintern(image)
100
+ if not extracted_texts:
101
+ return "❌ Không đọc được nội dung từ ảnh"
102
+ results = predict_phobert(extracted_texts)
103
+ elif text.strip() != "":
104
+ results = predict_phobert([text])
105
+ else:
106
+ return "⚠️ Vui lòng nhập văn bản hoặc tải ảnh"
107
+
108
+ output_str = "\n".join([f"{r['text']} → {r['prediction']} ({r['confidence']})" for r in results])
109
+ return output_str
110
+
111
+ demo = gr.Interface(
112
+ fn=detect,
113
+ inputs=[
114
+ gr.Image(type="pil", label="Tải ảnh tin nhắn"),
115
+ gr.Textbox(label="Hoặc nhập văn bản")
116
+ ],
117
+ outputs=gr.Textbox(label="Kết quả"),
118
+ title="🛡️ Bộ phát hiện lừa đảo",
119
+ description="Nhập văn bản hoặc tải ảnh chụp màn hình tin nhắn để kiểm tra."
120
+ )
121
+
122
+ if __name__ == "__main__":
123
+ demo.launch()