Update app.py
Browse files
app.py
CHANGED
|
@@ -1,10 +1,7 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
Vietnamese Receipt Classification App for Hugging Face Spaces
|
| 4 |
-
|
| 5 |
-
1. Train model on startup
|
| 6 |
-
2. Vision Language Model for bill description
|
| 7 |
-
3. Classification prediction
|
| 8 |
"""
|
| 9 |
|
| 10 |
import os
|
|
@@ -19,7 +16,6 @@ import threading
|
|
| 19 |
import time
|
| 20 |
import io
|
| 21 |
from PIL import Image
|
| 22 |
-
import base64
|
| 23 |
|
| 24 |
# Add paths for imports
|
| 25 |
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
@@ -48,6 +44,7 @@ vectorizers = None
|
|
| 48 |
label_encoder = None
|
| 49 |
training_status = "Not started"
|
| 50 |
training_log = []
|
|
|
|
| 51 |
|
| 52 |
# ====================================
|
| 53 |
# GOOGLE AI VISION SETUP
|
|
@@ -57,7 +54,6 @@ def setup_google_ai():
|
|
| 57 |
if not GOOGLE_AI_AVAILABLE:
|
| 58 |
return None
|
| 59 |
|
| 60 |
-
# Get API key from environment or Hugging Face secrets
|
| 61 |
api_key = os.getenv('GOOGLE_AI_API_KEY') or os.getenv('GOOGLE_API_KEY')
|
| 62 |
|
| 63 |
if not api_key:
|
|
@@ -66,8 +62,6 @@ def setup_google_ai():
|
|
| 66 |
|
| 67 |
try:
|
| 68 |
genai.configure(api_key=api_key)
|
| 69 |
-
|
| 70 |
-
# Initialize vision model
|
| 71 |
model = genai.GenerativeModel('gemini-1.5-flash')
|
| 72 |
print("✅ Google AI Vision model initialized")
|
| 73 |
return model
|
|
@@ -82,9 +76,10 @@ google_vision_model = setup_google_ai()
|
|
| 82 |
# ====================================
|
| 83 |
def train_model_background():
|
| 84 |
"""Train model in background thread"""
|
| 85 |
-
global trained_model, feature_type, vectorizers, label_encoder, training_status, training_log
|
| 86 |
|
| 87 |
try:
|
|
|
|
| 88 |
training_status = "Starting training..."
|
| 89 |
training_log.append(f"[{datetime.now().strftime('%H:%M:%S')}] Starting training...")
|
| 90 |
|
|
@@ -92,9 +87,10 @@ def train_model_background():
|
|
| 92 |
if not os.path.exists(Config.DATA_FILE):
|
| 93 |
training_status = "Error: Dataset not found"
|
| 94 |
training_log.append(f"[{datetime.now().strftime('%H:%M:%S')}] ❌ Dataset {Config.DATA_FILE} not found")
|
|
|
|
| 95 |
return
|
| 96 |
|
| 97 |
-
training_status = "Training in progress..."
|
| 98 |
training_log.append(f"[{datetime.now().strftime('%H:%M:%S')}] 🚀 Training started")
|
| 99 |
|
| 100 |
# Initialize trainer
|
|
@@ -110,24 +106,33 @@ def train_model_background():
|
|
| 110 |
label_encoder = trainer.data_loader.label_encoder
|
| 111 |
|
| 112 |
accuracy = results.get('accuracy', 0)
|
| 113 |
-
training_status = f"Training completed! Accuracy: {accuracy:.4f}"
|
| 114 |
training_log.append(f"[{datetime.now().strftime('%H:%M:%S')}] ✅ Training completed with {accuracy:.4f} accuracy")
|
| 115 |
|
| 116 |
except Exception as e:
|
| 117 |
-
training_status = f"Training failed: {str(e)}"
|
| 118 |
training_log.append(f"[{datetime.now().strftime('%H:%M:%S')}] ❌ Training failed: {str(e)}")
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
def get_training_status():
|
| 121 |
"""Get current training status"""
|
| 122 |
-
log_text = "\n".join(training_log[-
|
| 123 |
return training_status, log_text
|
| 124 |
|
| 125 |
def start_training():
    """Kick off ``train_model_background`` on a daemon thread and return immediately."""
    # daemon=True is equivalent to assigning ``thread.daemon = True`` before start():
    # the worker will not keep the interpreter alive at shutdown.
    worker = threading.Thread(target=train_model_background, daemon=True)
    worker.start()
|
| 130 |
-
|
|
|
|
| 131 |
|
| 132 |
# ====================================
|
| 133 |
# VISION MODEL FUNCTIONS
|
|
@@ -135,7 +140,7 @@ def start_training():
|
|
| 135 |
def extract_bill_description(image):
|
| 136 |
"""Extract bill description using Google Vision AI"""
|
| 137 |
if not GOOGLE_AI_AVAILABLE or google_vision_model is None:
|
| 138 |
-
return "❌ Google AI Vision không khả dụng. Vui lòng nhập mô tả thủ công."
|
| 139 |
|
| 140 |
try:
|
| 141 |
if image is None:
|
|
@@ -184,10 +189,10 @@ def predict_bill_class(description):
|
|
| 184 |
global trained_model, feature_type, vectorizers, label_encoder
|
| 185 |
|
| 186 |
if trained_model is None:
|
| 187 |
-
return "❌ Model chưa được train. Vui lòng đợi quá trình training hoàn tất.", "", ""
|
| 188 |
|
| 189 |
if not description or description.strip() == "":
|
| 190 |
-
return "❌ Vui lòng nhập mô tả hóa đơn", "", ""
|
| 191 |
|
| 192 |
try:
|
| 193 |
# Predict
|
|
@@ -209,11 +214,12 @@ def predict_bill_class(description):
|
|
| 209 |
|
| 210 |
result_text = f"🎯 Dự đoán: {predicted_class}\n📊 Độ tin cậy: {confidence:.3f}"
|
| 211 |
top_3_text = "📊 Top 3 dự đoán:\n" + "\n".join(top_3_results)
|
|
|
|
| 212 |
|
| 213 |
-
return result_text, top_3_text,
|
| 214 |
|
| 215 |
except Exception as e:
|
| 216 |
-
return f"❌ Lỗi khi dự đoán: {str(e)}", "", ""
|
| 217 |
|
| 218 |
def predict_from_image_and_text(image, manual_description):
|
| 219 |
"""Combined prediction from image and manual text"""
|
|
@@ -221,21 +227,24 @@ def predict_from_image_and_text(image, manual_description):
|
|
| 221 |
# Use manual description if provided, otherwise extract from image
|
| 222 |
if manual_description and manual_description.strip():
|
| 223 |
description = manual_description.strip()
|
| 224 |
-
|
| 225 |
elif image is not None:
|
| 226 |
description = extract_bill_description(image)
|
| 227 |
-
|
| 228 |
|
| 229 |
# Check if extraction failed
|
| 230 |
if description.startswith("❌"):
|
| 231 |
-
return description, "",
|
| 232 |
else:
|
| 233 |
-
return "❌ Vui lòng upload ảnh hoặc nhập mô tả thủ công", "", "", ""
|
| 234 |
|
| 235 |
# Make prediction
|
| 236 |
result, top_3, status = predict_bill_class(description)
|
| 237 |
|
| 238 |
-
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
# ====================================
|
| 241 |
# GRADIO INTERFACE
|
|
@@ -243,197 +252,322 @@ def predict_from_image_and_text(image, manual_description):
|
|
| 243 |
def create_interface():
|
| 244 |
"""Create Gradio interface"""
|
| 245 |
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
|
|
|
|
| 248 |
gr.HTML("""
|
| 249 |
-
<div
|
| 250 |
<h1>🧾 Vietnamese Receipt Classification</h1>
|
| 251 |
-
<p>Ứng dụng phân loại hóa đơn Việt Nam sử dụng GA-optimized
|
| 252 |
</div>
|
| 253 |
""")
|
| 254 |
|
| 255 |
-
with gr.Tabs():
|
| 256 |
|
| 257 |
# ====================================
|
| 258 |
# TAB 1: MODEL TRAINING
|
| 259 |
# ====================================
|
| 260 |
-
with gr.Tab("🚀 Model Training"):
|
| 261 |
-
|
|
|
|
|
|
|
| 262 |
|
| 263 |
with gr.Row():
|
| 264 |
with gr.Column(scale=1):
|
| 265 |
-
train_btn = gr.Button(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
with gr.Column(scale=1):
|
| 267 |
-
refresh_btn = gr.Button(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
|
|
|
|
|
|
| 274 |
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
|
|
|
|
|
|
| 281 |
|
|
|
|
| 282 |
gr.HTML("""
|
| 283 |
-
<div style="margin-top: 20px; padding:
|
| 284 |
-
<h4>📋 Training Information
|
| 285 |
-
<ul>
|
| 286 |
-
<li>Algorithm
|
| 287 |
-
<li>Features
|
| 288 |
-
<li>Optimization
|
| 289 |
-
<li>
|
|
|
|
|
|
|
| 290 |
</ul>
|
| 291 |
</div>
|
| 292 |
""")
|
| 293 |
|
| 294 |
-
# Event handlers
|
| 295 |
-
train_btn.click(
|
| 296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
| 298 |
# ====================================
|
| 299 |
# TAB 2: BILL CLASSIFICATION
|
| 300 |
# ====================================
|
| 301 |
-
with gr.Tab("🔮 Bill Classification"):
|
| 302 |
|
| 303 |
-
gr.HTML("<h3
|
| 304 |
|
| 305 |
with gr.Row():
|
|
|
|
| 306 |
with gr.Column(scale=1):
|
| 307 |
gr.HTML("<h4>📸 Upload ảnh hóa đơn</h4>")
|
|
|
|
| 308 |
image_input = gr.Image(
|
| 309 |
label="Ảnh hóa đơn",
|
| 310 |
type="pil",
|
| 311 |
-
height=
|
| 312 |
)
|
| 313 |
|
| 314 |
-
extract_btn = gr.Button(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
|
| 316 |
gr.HTML("<h4>📝 Hoặc nhập mô tả thủ công</h4>")
|
|
|
|
| 317 |
manual_input = gr.Textbox(
|
| 318 |
label="Mô tả hóa đơn",
|
| 319 |
placeholder="Ví dụ: Hóa đơn thanh toán tại cửa hàng cà phê Feel Coffee với món Yogurt Very Berry giá 22.000 VND",
|
| 320 |
-
lines=
|
| 321 |
-
max_lines=
|
| 322 |
)
|
| 323 |
|
| 324 |
-
predict_btn = gr.Button(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
|
|
|
|
| 326 |
with gr.Column(scale=1):
|
| 327 |
-
gr.HTML("<h4>📄
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
|
|
|
| 331 |
interactive=False
|
| 332 |
)
|
| 333 |
|
| 334 |
gr.HTML("<h4>🎯 Kết quả phân loại</h4>")
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
|
|
|
| 338 |
interactive=False
|
| 339 |
)
|
| 340 |
|
| 341 |
-
|
| 342 |
label="Top 3 dự đoán",
|
| 343 |
lines=4,
|
| 344 |
interactive=False
|
| 345 |
)
|
| 346 |
|
| 347 |
-
|
| 348 |
label="Trạng thái",
|
| 349 |
-
lines=
|
| 350 |
interactive=False
|
| 351 |
)
|
| 352 |
|
| 353 |
-
#
|
| 354 |
-
gr.
|
| 355 |
-
|
| 356 |
-
<
|
| 357 |
-
|
| 358 |
-
<
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
|
| 366 |
-
# Event handlers
|
| 367 |
extract_btn.click(
|
| 368 |
-
process_image_and_extract,
|
| 369 |
inputs=[image_input],
|
| 370 |
outputs=[manual_input]
|
| 371 |
)
|
| 372 |
|
| 373 |
predict_btn.click(
|
| 374 |
-
predict_from_image_and_text,
|
| 375 |
inputs=[image_input, manual_input],
|
| 376 |
-
outputs=[
|
| 377 |
)
|
| 378 |
|
| 379 |
# ====================================
|
| 380 |
-
# TAB 3: ABOUT
|
| 381 |
# ====================================
|
| 382 |
-
with gr.Tab("ℹ️ About"):
|
|
|
|
| 383 |
gr.HTML("""
|
| 384 |
-
<div style="padding: 20px;">
|
| 385 |
-
<
|
| 386 |
|
| 387 |
-
<
|
| 388 |
-
|
| 389 |
-
<
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
<li><strong>Machine Learning:</strong> scikit-learn, sentence-transformers</li>
|
| 398 |
-
<li><strong>Optimization:</strong> DEAP (Genetic Algorithm)</li>
|
| 399 |
-
<li><strong>Computer Vision:</strong> Google Gemini Vision API</li>
|
| 400 |
-
<li><strong>Interface:</strong> Gradio, Hugging Face Spaces</li>
|
| 401 |
-
</ul>
|
| 402 |
|
| 403 |
-
<
|
| 404 |
-
|
| 405 |
-
<
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
|
| 410 |
-
<
|
| 411 |
-
|
| 412 |
-
<
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 417 |
|
| 418 |
-
<div style="margin
|
| 419 |
-
<
|
| 420 |
<ol>
|
| 421 |
-
<li>Bắt đầu
|
| 422 |
-
<li
|
| 423 |
-
<li>
|
| 424 |
-
<li>
|
| 425 |
-
<li>
|
|
|
|
|
|
|
| 426 |
</ol>
|
| 427 |
</div>
|
| 428 |
|
| 429 |
-
<div style="margin
|
| 430 |
-
<
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
</div>
|
| 432 |
</div>
|
| 433 |
""")
|
| 434 |
|
| 435 |
-
#
|
| 436 |
-
interface.load(
|
|
|
|
|
|
|
|
|
|
| 437 |
|
| 438 |
return interface
|
| 439 |
|
|
@@ -442,6 +576,7 @@ def create_interface():
|
|
| 442 |
# ====================================
|
| 443 |
if __name__ == "__main__":
|
| 444 |
print("🚀 Starting Vietnamese Receipt Classification App...")
|
|
|
|
| 445 |
|
| 446 |
# Check dependencies
|
| 447 |
print("📋 Checking dependencies...")
|
|
@@ -449,21 +584,27 @@ if __name__ == "__main__":
|
|
| 449 |
if GOOGLE_AI_AVAILABLE and google_vision_model is not None:
|
| 450 |
print("✅ Google AI Vision: Ready")
|
| 451 |
else:
|
| 452 |
-
print("⚠️ Google AI Vision: Not available
|
|
|
|
| 453 |
|
| 454 |
# Check dataset
|
| 455 |
if os.path.exists(Config.DATA_FILE):
|
| 456 |
print(f"✅ Dataset: Found {Config.DATA_FILE}")
|
| 457 |
else:
|
| 458 |
print(f"⚠️ Dataset: {Config.DATA_FILE} not found")
|
|
|
|
| 459 |
|
| 460 |
print("🎨 Creating Gradio interface...")
|
| 461 |
app = create_interface()
|
| 462 |
|
| 463 |
print("🌐 Launching app...")
|
|
|
|
|
|
|
|
|
|
| 464 |
app.launch(
|
| 465 |
-
share=True,
|
| 466 |
server_name="0.0.0.0",
|
| 467 |
server_port=7860,
|
| 468 |
-
|
| 469 |
-
|
|
|
|
|
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
Vietnamese Receipt Classification App for Hugging Face Spaces
|
| 4 |
+
Compatible with current Gradio version
|
|
|
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
import os
|
|
|
|
| 16 |
import time
|
| 17 |
import io
|
| 18 |
from PIL import Image
|
|
|
|
| 19 |
|
| 20 |
# Add paths for imports
|
| 21 |
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
|
| 44 |
label_encoder = None
|
| 45 |
training_status = "Not started"
|
| 46 |
training_log = []
|
| 47 |
+
is_training = False
|
| 48 |
|
| 49 |
# ====================================
|
| 50 |
# GOOGLE AI VISION SETUP
|
|
|
|
| 54 |
if not GOOGLE_AI_AVAILABLE:
|
| 55 |
return None
|
| 56 |
|
|
|
|
| 57 |
api_key = os.getenv('GOOGLE_AI_API_KEY') or os.getenv('GOOGLE_API_KEY')
|
| 58 |
|
| 59 |
if not api_key:
|
|
|
|
| 62 |
|
| 63 |
try:
|
| 64 |
genai.configure(api_key=api_key)
|
|
|
|
|
|
|
| 65 |
model = genai.GenerativeModel('gemini-1.5-flash')
|
| 66 |
print("✅ Google AI Vision model initialized")
|
| 67 |
return model
|
|
|
|
| 76 |
# ====================================
|
| 77 |
def train_model_background():
|
| 78 |
"""Train model in background thread"""
|
| 79 |
+
global trained_model, feature_type, vectorizers, label_encoder, training_status, training_log, is_training
|
| 80 |
|
| 81 |
try:
|
| 82 |
+
is_training = True
|
| 83 |
training_status = "Starting training..."
|
| 84 |
training_log.append(f"[{datetime.now().strftime('%H:%M:%S')}] Starting training...")
|
| 85 |
|
|
|
|
| 87 |
if not os.path.exists(Config.DATA_FILE):
|
| 88 |
training_status = "Error: Dataset not found"
|
| 89 |
training_log.append(f"[{datetime.now().strftime('%H:%M:%S')}] ❌ Dataset {Config.DATA_FILE} not found")
|
| 90 |
+
is_training = False
|
| 91 |
return
|
| 92 |
|
| 93 |
+
training_status = "Training in progress... (This may take 10-15 minutes)"
|
| 94 |
training_log.append(f"[{datetime.now().strftime('%H:%M:%S')}] 🚀 Training started")
|
| 95 |
|
| 96 |
# Initialize trainer
|
|
|
|
| 106 |
label_encoder = trainer.data_loader.label_encoder
|
| 107 |
|
| 108 |
accuracy = results.get('accuracy', 0)
|
| 109 |
+
training_status = f"✅ Training completed! Accuracy: {accuracy:.4f}"
|
| 110 |
training_log.append(f"[{datetime.now().strftime('%H:%M:%S')}] ✅ Training completed with {accuracy:.4f} accuracy")
|
| 111 |
|
| 112 |
except Exception as e:
|
| 113 |
+
training_status = f"❌ Training failed: {str(e)}"
|
| 114 |
training_log.append(f"[{datetime.now().strftime('%H:%M:%S')}] ❌ Training failed: {str(e)}")
|
| 115 |
+
|
| 116 |
+
finally:
|
| 117 |
+
is_training = False
|
| 118 |
|
| 119 |
def get_training_status():
    """Return the current training status and the tail of the training log.

    Returns:
        tuple: ``(training_status, log_text)`` where ``log_text`` is the last
        15 entries of ``training_log`` joined with newlines.
    """
    recent_entries = training_log[-15:]  # Last 15 messages
    return training_status, "\n".join(recent_entries)
|
| 123 |
|
| 124 |
# Serializes the "is a run already active?" check with the launch of a new
# run, so two near-simultaneous button clicks cannot both start a worker.
_training_start_lock = threading.Lock()


def start_training():
    """Start model training on a background daemon thread.

    The ``is_training`` flag is claimed here, under a lock, *before* the worker
    thread is started. Previously the flag was only set inside the worker, so a
    second call arriving before the worker had run its first statement passed
    the guard and launched a duplicate training run.

    Returns:
        tuple: ``(status_message, log_text)`` strings. If a run is already
        active, the log text is the last 15 entries of ``training_log``;
        otherwise it is a fixed "Training initiated..." placeholder.

    Raises:
        Exception: whatever ``threading.Thread.start`` raises if the worker
        cannot be started; ``is_training`` is reset before re-raising.
    """
    global is_training

    with _training_start_lock:
        if is_training:
            return "⚠️ Training already in progress...", "\n".join(training_log[-15:])

        # Claim the run now; the worker sets the flag again on entry and
        # clears it in its ``finally`` block when it ends.
        is_training = True
        try:
            thread = threading.Thread(target=train_model_background, daemon=True)
            thread.start()
        except Exception:
            # The worker never ran, so nothing else will release the flag.
            is_training = False
            raise

    return "🚀 Training started in background...", "Training initiated..."
|
| 136 |
|
| 137 |
# ====================================
|
| 138 |
# VISION MODEL FUNCTIONS
|
|
|
|
| 140 |
def extract_bill_description(image):
|
| 141 |
"""Extract bill description using Google Vision AI"""
|
| 142 |
if not GOOGLE_AI_AVAILABLE or google_vision_model is None:
|
| 143 |
+
return "❌ Google AI Vision không khả dụng. Vui lòng thiết lập GOOGLE_AI_API_KEY hoặc nhập mô tả thủ công."
|
| 144 |
|
| 145 |
try:
|
| 146 |
if image is None:
|
|
|
|
| 189 |
global trained_model, feature_type, vectorizers, label_encoder
|
| 190 |
|
| 191 |
if trained_model is None:
|
| 192 |
+
return "❌ Model chưa được train. Vui lòng đợi quá trình training hoàn tất.", "", "Model not ready"
|
| 193 |
|
| 194 |
if not description or description.strip() == "":
|
| 195 |
+
return "❌ Vui lòng nhập mô tả hóa đơn", "", "Empty description"
|
| 196 |
|
| 197 |
try:
|
| 198 |
# Predict
|
|
|
|
| 214 |
|
| 215 |
result_text = f"🎯 Dự đoán: {predicted_class}\n📊 Độ tin cậy: {confidence:.3f}"
|
| 216 |
top_3_text = "📊 Top 3 dự đoán:\n" + "\n".join(top_3_results)
|
| 217 |
+
status = f"✅ Đã phân loại thành công với độ tin cậy {confidence:.1%}"
|
| 218 |
|
| 219 |
+
return result_text, top_3_text, status
|
| 220 |
|
| 221 |
except Exception as e:
|
| 222 |
+
return f"❌ Lỗi khi dự đoán: {str(e)}", "", f"Error: {str(e)}"
|
| 223 |
|
| 224 |
def predict_from_image_and_text(image, manual_description):
|
| 225 |
"""Combined prediction from image and manual text"""
|
|
|
|
| 227 |
# Use manual description if provided, otherwise extract from image
|
| 228 |
if manual_description and manual_description.strip():
|
| 229 |
description = manual_description.strip()
|
| 230 |
+
source_info = "📝 Sử dụng mô tả thủ công"
|
| 231 |
elif image is not None:
|
| 232 |
description = extract_bill_description(image)
|
| 233 |
+
source_info = "🖼️ Trích xuất từ ảnh"
|
| 234 |
|
| 235 |
# Check if extraction failed
|
| 236 |
if description.startswith("❌"):
|
| 237 |
+
return description, "", description, description
|
| 238 |
else:
|
| 239 |
+
return "❌ Vui lòng upload ảnh hoặc nhập mô tả thủ công", "", "No input provided", ""
|
| 240 |
|
| 241 |
# Make prediction
|
| 242 |
result, top_3, status = predict_bill_class(description)
|
| 243 |
|
| 244 |
+
# Prepare full description info
|
| 245 |
+
full_description = f"{source_info}\n\n📄 Mô tả hóa đơn:\n{description}"
|
| 246 |
+
|
| 247 |
+
return result, top_3, status, full_description
|
| 248 |
|
| 249 |
# ====================================
|
| 250 |
# GRADIO INTERFACE
|
|
|
|
| 252 |
def create_interface():
    """Build the Gradio Blocks UI for the app and return it without launching.

    The layout has a header and three tabs:

    1. "Model Training" - buttons wired to ``start_training`` and
       ``get_training_status``, plus status and log text boxes.
    2. "Bill Classification" - image upload / manual text inputs, and four
       read-only text boxes filled by ``predict_from_image_and_text``.
    3. "About & Help" - static HTML only.

    On page load the training status and log boxes are populated via
    ``get_training_status``.

    Returns:
        The ``gr.Blocks`` instance created here; the caller is responsible for
        calling ``launch()`` on it.

    NOTE(review): ``process_image_and_extract`` is referenced below but is not
    defined in the portion of the file visible here - confirm it exists and
    returns a single string for the manual-description text box.
    """

    # Custom CSS
    # Passed to gr.Blocks below. Only ``main-header`` is referenced by the HTML
    # built in this function; the other classes are not used here.
    css = """
    .gradio-container {
        max-width: 1200px !important;
    }
    .main-header {
        text-align: center;
        margin: 20px 0;
        padding: 20px;
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
        color: white;
        border-radius: 10px;
    }
    .tab-nav {
        margin: 10px 0;
    }
    .status-box {
        border: 2px solid #e1e5e9;
        border-radius: 8px;
        padding: 15px;
        margin: 10px 0;
    }
    .success-status {
        border-color: #28a745;
        background-color: #f8fff9;
    }
    .error-status {
        border-color: #dc3545;
        background-color: #fff8f8;
    }
    .warning-status {
        border-color: #ffc107;
        background-color: #fffbf0;
    }
    """

    with gr.Blocks(title="Vietnamese Receipt Classification", css=css, theme=gr.themes.Soft()) as interface:

        # Header
        gr.HTML("""
        <div class="main-header">
            <h1>🧾 Vietnamese Receipt Classification</h1>
            <p>Ứng dụng phân loại hóa đơn Việt Nam sử dụng GA-optimized Ensemble + Google AI Vision</p>
        </div>
        """)

        # NOTE: ``tabs`` is bound but not referenced again in this function.
        with gr.Tabs() as tabs:

            # ====================================
            # TAB 1: MODEL TRAINING
            # ====================================
            with gr.Tab("🚀 Model Training", id="training"):

                with gr.Row():
                    gr.HTML("<h3>🏋️ Training Management</h3>")

                with gr.Row():
                    with gr.Column(scale=1):
                        train_btn = gr.Button(
                            "🚀 Start Training",
                            variant="primary",
                            size="lg"
                        )
                    with gr.Column(scale=1):
                        refresh_btn = gr.Button(
                            "🔄 Refresh Status",
                            variant="secondary",
                            size="lg"
                        )

                with gr.Row():
                    status_display = gr.Textbox(
                        label="📊 Training Status",
                        value="Click 'Start Training' to begin",
                        interactive=False,
                        lines=2
                    )

                with gr.Row():
                    log_display = gr.Textbox(
                        label="📝 Training Log",
                        lines=10,
                        max_lines=15,
                        interactive=False,
                        placeholder="Training logs will appear here..."
                    )

                # Training info
                # Static description only; the figures below are hard-coded text,
                # not values read from the trainer.
                gr.HTML("""
                <div style="margin-top: 20px; padding: 20px; background-color: #f8f9fa; border-radius: 8px; border-left: 4px solid #007bff;">
                    <h4>📋 Training Information</h4>
                    <ul style="margin: 10px 0; padding-left: 20px;">
                        <li><strong>Algorithm:</strong> GA-optimized Voting Ensemble (KNN + Decision Tree + Naive Bayes)</li>
                        <li><strong>Features:</strong> BoW, TF-IDF, Sentence Embeddings (all-MiniLM-L6-v2)</li>
                        <li><strong>Optimization:</strong> Genetic Algorithm (Population: 20, Generations: 10)</li>
                        <li><strong>Evaluation:</strong> 3-fold Cross-Validation</li>
                        <li><strong>Expected Time:</strong> 10-15 minutes on free tier</li>
                        <li><strong>Expected Accuracy:</strong> 85-95% depending on dataset quality</li>
                    </ul>
                </div>
                """)

                # Event handlers for training tab
                # Both handlers take no inputs and fill (status, log) in that order.
                train_btn.click(
                    fn=start_training,
                    outputs=[status_display, log_display]
                )

                refresh_btn.click(
                    fn=get_training_status,
                    outputs=[status_display, log_display]
                )

            # ====================================
            # TAB 2: BILL CLASSIFICATION
            # ====================================
            with gr.Tab("🔮 Bill Classification", id="classification"):

                gr.HTML("<h3>🎯 Phân loại hóa đơn từ ảnh hoặc text</h3>")

                with gr.Row():
                    # Left column - Input
                    with gr.Column(scale=1):
                        gr.HTML("<h4>📸 Upload ảnh hóa đơn</h4>")

                        image_input = gr.Image(
                            label="Ảnh hóa đơn",
                            type="pil",
                            height=250
                        )

                        extract_btn = gr.Button(
                            "🔍 Trích xuất mô tả từ ảnh",
                            variant="secondary",
                            size="sm"
                        )

                        gr.HTML("<h4>📝 Hoặc nhập mô tả thủ công</h4>")

                        manual_input = gr.Textbox(
                            label="Mô tả hóa đơn",
                            placeholder="Ví dụ: Hóa đơn thanh toán tại cửa hàng cà phê Feel Coffee với món Yogurt Very Berry giá 22.000 VND",
                            lines=4,
                            max_lines=6
                        )

                        predict_btn = gr.Button(
                            "🎯 Dự đoán phân loại",
                            variant="primary",
                            size="lg"
                        )

                    # Right column - Output
                    with gr.Column(scale=1):
                        gr.HTML("<h4>📄 Thông tin đã xử lý</h4>")

                        processed_info = gr.Textbox(
                            label="Nguồn và mô tả",
                            lines=6,
                            interactive=False
                        )

                        gr.HTML("<h4>🎯 Kết quả phân loại</h4>")

                        result_display = gr.Textbox(
                            label="Dự đoán chính",
                            lines=3,
                            interactive=False
                        )

                        top3_display = gr.Textbox(
                            label="Top 3 dự đoán",
                            lines=4,
                            interactive=False
                        )

                        status_output = gr.Textbox(
                            label="Trạng thái",
                            lines=2,
                            interactive=False
                        )

                # Examples section
                with gr.Row():
                    gr.HTML("""
                    <div style="margin-top: 20px; padding: 15px; background-color: #e8f4fd; border-radius: 8px;">
                        <h4>💡 Ví dụ các loại hóa đơn</h4>
                        <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 20px; margin-top: 10px;">
                            <div>
                                <ul style="margin: 0; padding-left: 20px;">
                                    <li><strong>Ăn uống ngoài hàng:</strong> Nhà hàng, quán cà phê, fast food</li>
                                    <li><strong>Siêu thị tổng hợp:</strong> VinMart, Co.opMart, Big C, Lotte</li>
                                </ul>
                            </div>
                            <div>
                                <ul style="margin: 0; padding-left: 20px;">
                                    <li><strong>Sữa & Đồ uống:</strong> Sữa, nước ngọt, đồ uống các loại</li>
                                    <li><strong>Tiện ích:</strong> Điện, nước, internet, di động</li>
                                </ul>
                            </div>
                        </div>
                    </div>
                    """)

                # Event handlers for classification tab
                # The extracted description is written into the manual-input box.
                extract_btn.click(
                    fn=process_image_and_extract,
                    inputs=[image_input],
                    outputs=[manual_input]
                )

                # Output order must match the 4-tuple returned by
                # predict_from_image_and_text: (result, top 3, status, description).
                predict_btn.click(
                    fn=predict_from_image_and_text,
                    inputs=[image_input, manual_input],
                    outputs=[result_display, top3_display, status_output, processed_info]
                )

            # ====================================
            # TAB 3: ABOUT & HELP
            # ====================================
            with gr.Tab("ℹ️ About & Help", id="about"):

                gr.HTML("""
                <div style="padding: 20px; max-width: 800px; margin: 0 auto;">
                    <h2>🧾 Vietnamese Receipt Classification System</h2>

                    <div style="margin: 20px 0; padding: 15px; background-color: #f8f9fa; border-radius: 8px;">
                        <h3>🎯 Tính năng chính</h3>
                        <ul>
                            <li><strong>🤖 AI Vision:</strong> Trích xuất mô tả từ ảnh hóa đơn bằng Google Gemini Vision API</li>
                            <li><strong>🧬 GA Optimization:</strong> Tối ưu hóa ensemble classifier bằng Genetic Algorithm</li>
                            <li><strong>📊 Multi-feature:</strong> Kết hợp BoW, TF-IDF và Sentence Embeddings</li>
                            <li><strong>🗳️ Voting Ensemble:</strong> KNN + Decision Tree + Naive Bayes với trọng số tối ưu</li>
                            <li><strong>⚡ Real-time:</strong> Training và prediction trực tiếp trên web</li>
                        </ul>
                    </div>

                    <div style="margin: 20px 0; padding: 15px; background-color: #e8f4fd; border-radius: 8px;">
                        <h3>🔧 Công nghệ sử dụng</h3>
                        <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 15px;">
                            <div>
                                <h4>Machine Learning:</h4>
                                <ul>
                                    <li>scikit-learn</li>
                                    <li>sentence-transformers</li>
                                    <li>DEAP (Genetic Algorithm)</li>
                                </ul>
                            </div>
                            <div>
                                <h4>AI Vision:</h4>
                                <ul>
                                    <li>Google Gemini Vision</li>
                                    <li>PIL (Image Processing)</li>
                                    <li>Base64 Encoding</li>
                                </ul>
                            </div>
                        </div>
                    </div>

                    <div style="margin: 20px 0; padding: 15px; background-color: #f0f8f0; border-radius: 8px;">
                        <h3>📊 Performance Metrics</h3>
                        <table style="width: 100%; border-collapse: collapse;">
                            <tr>
                                <td style="padding: 8px; border: 1px solid #ddd;"><strong>Accuracy</strong></td>
                                <td style="padding: 8px; border: 1px solid #ddd;">85-95%</td>
                            </tr>
                            <tr>
                                <td style="padding: 8px; border: 1px solid #ddd;"><strong>Training Time</strong></td>
                                <td style="padding: 8px; border: 1px solid #ddd;">10-15 minutes</td>
                            </tr>
                            <tr>
                                <td style="padding: 8px; border: 1px solid #ddd;"><strong>Prediction Time</strong></td>
                                <td style="padding: 8px; border: 1px solid #ddd;">< 2 seconds</td>
                            </tr>
                            <tr>
                                <td style="padding: 8px; border: 1px solid #ddd;"><strong>Model Size</strong></td>
                                <td style="padding: 8px; border: 1px solid #ddd;">~5MB (lightweight mode)</td>
                            </tr>
                        </table>
                    </div>

                    <div style="margin: 20px 0; padding: 15px; background-color: #fff8dc; border-radius: 8px;">
                        <h3>🚀 Hướng dẫn sử dụng</h3>
                        <ol>
                            <li><strong>Training:</strong> Bắt đầu với tab "🚀 Model Training", click "Start Training" và đợi 10-15 phút</li>
                            <li><strong>Classification:</strong> Chuyển sang tab "🔮 Bill Classification"</li>
                            <li><strong>Upload ảnh:</strong> Kéo thả ảnh hóa đơn vào khung "Upload ảnh hóa đơn"</li>
                            <li><strong>Extract text:</strong> Click "🔍 Trích xuất mô tả từ ảnh" (cần Google AI API key)</li>
                            <li><strong>Manual input:</strong> Hoặc nhập mô tả thủ công vào text box</li>
                            <li><strong>Predict:</strong> Click "🎯 Dự đoán phân loại" để xem kết quả</li>
                            <li><strong>Results:</strong> Xem dự đoán chính + top 3 alternatives với confidence scores</li>
                        </ol>
                    </div>

                    <div style="margin: 20px 0; padding: 15px; background-color: #ffe6e6; border-radius: 8px;">
                        <h3>⚠️ Lưu ý quan trọng</h3>
                        <ul>
                            <li><strong>Google AI API:</strong> Để sử dụng tính năng trích xuất từ ảnh, cần thiết lập GOOGLE_AI_API_KEY trong environment variables</li>
                            <li><strong>Dataset:</strong> App cần file viet_receipt_categorized_label.xlsx để training</li>
                            <li><strong>Memory:</strong> Training có thể tốn nhiều RAM, nên dùng trên máy có đủ bộ nhớ</li>
                            <li><strong>Time:</strong> Quá trình training mất 10-15 phút, vui lòng kiên nhẫn</li>
                        </ul>
                    </div>

                    <div style="text-align: center; margin-top: 30px; padding: 20px; background: linear-gradient(45deg, #667eea, #764ba2); color: white; border-radius: 8px;">
                        <h3>🎉 Developed with ❤️ for Vietnamese NLP Community</h3>
                        <p>Powered by Hugging Face 🤗 | Google AI Studio | Gradio</p>
                    </div>
                </div>
                """)

        # Load initial status when interface starts
        interface.load(
            fn=get_training_status,
            outputs=[status_display, log_display]
        )

    return interface
|
| 573 |
|
|
|
|
| 576 |
# ====================================
|
| 577 |
if __name__ == "__main__":
|
| 578 |
print("🚀 Starting Vietnamese Receipt Classification App...")
|
| 579 |
+
print("="*60)
|
| 580 |
|
| 581 |
# Check dependencies
|
| 582 |
print("📋 Checking dependencies...")
|
|
|
|
| 584 |
if GOOGLE_AI_AVAILABLE and google_vision_model is not None:
|
| 585 |
print("✅ Google AI Vision: Ready")
|
| 586 |
else:
|
| 587 |
+
print("⚠️ Google AI Vision: Not available")
|
| 588 |
+
print(" 💡 Set GOOGLE_AI_API_KEY environment variable to enable")
|
| 589 |
|
| 590 |
# Check dataset
|
| 591 |
if os.path.exists(Config.DATA_FILE):
|
| 592 |
print(f"✅ Dataset: Found {Config.DATA_FILE}")
|
| 593 |
else:
|
| 594 |
print(f"⚠️ Dataset: {Config.DATA_FILE} not found")
|
| 595 |
+
print(" 💡 Upload dataset file to enable training")
|
| 596 |
|
| 597 |
print("🎨 Creating Gradio interface...")
|
| 598 |
app = create_interface()
|
| 599 |
|
| 600 |
print("🌐 Launching app...")
|
| 601 |
+
print("="*60)
|
| 602 |
+
|
| 603 |
+
# Launch with appropriate settings
|
| 604 |
app.launch(
|
|
|
|
| 605 |
server_name="0.0.0.0",
|
| 606 |
server_port=7860,
|
| 607 |
+
share=False, # Set to True for public sharing
|
| 608 |
+
show_error=True,
|
| 609 |
+
show_tips=True,
|
| 610 |
+
enable_queue=True
|