|
|
from __future__ import annotations |
|
|
import os, io, re, json, time, mimetypes, tempfile |
|
|
from typing import List, Union, Tuple, Any |
|
|
from PIL import Image |
|
|
import pandas as pd |
|
|
import gradio as gr |
|
|
import google.generativeai as genai |
|
|
import requests |
|
|
import fitz |
|
|
import camelot |
|
|
import pdfplumber |
|
|
import random |
|
|
|
|
|
|
|
|
|
|
|
def _load_api_keys(env=None):
    """Return the Gemini API keys configured in the environment.

    Keys are read from ``GOOGLE_API_KEYS`` (several keys separated by commas
    and/or whitespace) or, when that is unset or empty, from the single-key
    variable ``GOOGLE_API_KEY``.

    Args:
        env: Mapping to read from; defaults to ``os.environ``. Accepting a
            mapping keeps the parsing testable without touching the process
            environment.

    Returns:
        The distinct keys in first-seen order; an empty list when nothing is
        configured.
    """
    source = os.environ if env is None else env
    raw = source.get("GOOGLE_API_KEYS") or source.get("GOOGLE_API_KEY") or ""
    # Treat commas as whitespace so "a,b", "a, b" and "a b" all parse alike.
    tokens = raw.replace(",", " ").split()
    # dict.fromkeys de-duplicates while preserving order.
    return list(dict.fromkeys(tokens))


# Pool of Gemini API keys; callers pick one at random per request.
# SECURITY: credentials must not live in source control. The keys that were
# previously hard-coded here have been published with the file and should be
# revoked; supply replacements through GOOGLE_API_KEYS / GOOGLE_API_KEY.
DEFAULT_API_KEY = _load_api_keys()
|
|
|
|
|
|
|
|
# Display label shown in the UI -> Gemini model identifier handed to the SDK.
INTERNAL_MODEL_MAP = {
    "Gemini 2.5 Flash": "gemini-2.5-flash",
    "Gemini 2.5 Pro": "gemini-2.5-pro",
}

# Additional label offered in the model dropdown alongside the Gemini entries.
EXTERNAL_MODEL_NAME = "prithivMLmods/Camel-Doc-OCR-062825 (External)"
|
|
# Default instruction sent to Gemini when the user supplies no prompt of their
# own: it describes the target JSON schema and the extraction rules.
# NOTE(review): the "Date rules" section asks for `UFN` when validity is
# missing, while STRICT RULES says "If validity missing, set null". The two
# conflict; confirm which one downstream consumers expect and drop the other.
PROMPT_FREIGHT_JSON = """
Please analyze the freight rate table in the file I provide and convert it into JSON in the following structure:
{
  "shipping_line": "...",
  "shipping_line_code": "...",
  "shipping_line_reason": "Why this carrier is chosen?",
  "fee_type": "Air Freight",
  "valid_from": ...,
  "valid_to": ...,
  "charges": [
    {
      "frequency": "...",
      "package_type": "...",
      "aircraft_type": "...",
      "direction": "Export or Import or null",
      "origin": "...",
      "destination": "...",
      "charge_name": "...",
      "charge_code": "GCR, DGR, PER, etc. (Use IATA Code DO NOT use flight number)",
      "charge_code_reason": "...",
      "cargo_type": "...",
      "currency": "...",
      "transit": "...",
      "transit_time": "...",
      "weight_breaks": {
        "M": ...,
        "N": ...,
        "+45kg": ...,
        "+100kg": ...,
        "+300kg": ...,
        "+500kg": ...,
        "+1000kg": ...,
        "other": {
          key: value
        },
        "weight_breaks_reason":"Why chosen weight_breaks?"
      },
      "remark": "..."
    }
  ],
  "local_charges": [
    {
      "charge_name": "...",
      "charge_code": "...",
      "unit": "...",
      "amount": ...,
      "remark": "..."
    }
  ]
}
### Date rules
- valid_from format:
  - `DD/MM/YYYY` (if full date)
  - `01/MM/YYYY` (if month+year only)
  - `01/01/YYYY` (if year only)
  - `UFN` if missing
- valid_to:
  - exact `DD/MM/YYYY` if present
  - else `UFN`
STRICT RULES:
- ONLY return a single JSON object as specified above.
- All rates must exactly match the corresponding weight break columns (M,N,45kg, 100kg, 300kg, 500kg, 1000kg, etc.). set null if N/A. No assumptions or interpolations.
- If the table shows "RQ" or similar, set value as "RQST".
- Group same-price destinations into one record separated by "/".
- Always use IATA code for origin and destination.
- Flight number (e.g. ZH118) is not charge code.
- Frequency: D[1-7]; 'Daily' = D1234567. Join multiple (e.g. D3,D4→D34).
- If local charges exist, list them.
- If validity missing, set null.
- Direction: Export if origin is Vietnam (SGN, HAN, DAD...), else Import.
- Provide short plain English reasons for "shipping_line_reason" & "charge_code_reason".
- Replace commas in remarks with semicolons.
- Only return JSON.
"""
|
|
|
|
|
|
|
|
|
|
|
def _read_file_bytes(upload: Union[str, os.PathLike, dict, object] | None) -> bytes:
    """Load the full content behind an upload reference.

    Supported inputs, checked in this order:
      * a filesystem path (``str`` or ``os.PathLike``),
      * a ``dict`` carrying a ``"path"`` entry,
      * any object exposing ``read()``, whose result is returned as-is.

    Raises:
        ValueError: ``upload`` is ``None``.
        TypeError: ``upload`` matches none of the supported shapes.
    """

    def _slurp(location) -> bytes:
        # Binary mode: the payload is a PDF or an image, never text.
        with open(location, "rb") as handle:
            return handle.read()

    if upload is None:
        raise ValueError("No file uploaded.")

    if isinstance(upload, (str, os.PathLike)):
        return _slurp(upload)

    if isinstance(upload, dict) and "path" in upload:
        return _slurp(upload["path"])

    if not hasattr(upload, "read"):
        raise TypeError(f"Unsupported file object: {type(upload)}")

    return upload.read()
|
|
|
|
|
def _guess_name_and_mime(file, file_bytes: bytes) -> Tuple[str, str]:
    """Derive a display filename and a MIME type for an upload.

    Args:
        file: The upload reference: a path, a ``dict`` with a ``"path"``
            entry, or an object with a ``name`` attribute.
        file_bytes: The upload's content, used to sniff the PDF signature when
            the filename gives no usable hint.

    Returns:
        ``(filename, mime)``; ``mime`` falls back to
        ``"application/octet-stream"`` when nothing better is known.
    """
    # Dict uploads carry their location under "path" (the same shape
    # _read_file_bytes accepts); stringifying the dict itself would yield a
    # garbage name and therefore a wrong MIME type.
    if isinstance(file, dict) and "path" in file:
        raw_name = file["path"]
    elif hasattr(file, "name"):
        raw_name = file.name
    else:
        raw_name = file

    if isinstance(raw_name, os.PathLike):
        raw_name = os.fspath(raw_name)
    filename = os.path.basename(str(raw_name))

    mime, _ = mimetypes.guess_type(filename)
    # No extension-based guess: fall back to the PDF magic bytes.
    if not mime and (file_bytes or b"")[:4] == b"%PDF":
        mime = "application/pdf"
    return filename, mime or "application/octet-stream"
|
|
|
|
|
def _tables_on_page(file_path: str, page_no: int, flavor: str, failure_icon: str, **options) -> list:
    """Run Camelot with one flavor on one page and return its non-empty tables.

    Failures are reported on stdout (prefixed with ``failure_icon``) and
    produce an empty list so the caller can try the next flavor.
    """
    try:
        tables = camelot.read_pdf(file_path, flavor=flavor, pages=str(page_no), **options)
        if not tables or tables.n <= 0:
            return []
        frames = [t.df for t in tables if t.shape[0] > 0]
        print(f"✅ Trang {page_no}: {flavor.capitalize()} thành công ({tables.n} bảng).")
        return frames
    except Exception as e:
        print(f"{failure_icon} Trang {page_no} {flavor} lỗi: {e}")
        return []


def _merge_page_tables(frames: list) -> pd.DataFrame:
    """Stack per-page tables into one frame and promote the first row to header.

    The promotion only happens when every column label is purely numeric,
    i.e. the frame still has positional labels instead of real headers.
    """
    if not frames:
        return pd.DataFrame()
    merged = pd.concat(frames, ignore_index=True)
    if all(str(c).isdigit() for c in merged.columns):
        merged.columns = merged.iloc[0]
        merged = merged[1:]
    return merged.dropna(how="all").reset_index(drop=True)


def extract_pdf_tables(file_path: str) -> pd.DataFrame:
    """Extract every table from a PDF with Camelot, one page at a time.

    For each page the ``lattice`` flavor is tried first; if it yields no
    table (or fails), the ``stream`` flavor is used as a fallback. All tables
    are then concatenated into a single DataFrame.

    Args:
        file_path: Path of the PDF on disk.

    Returns:
        The merged table, or an empty DataFrame when no table was detected.
        Progress is reported on stdout.
    """
    # Close the document as soon as the page count is known instead of
    # leaving the handle open for the lifetime of the process.
    with fitz.open(file_path) as doc:
        total_pages = len(doc)
    print(f"📄 Tổng số trang: {total_pages}")

    all_dfs = []
    for page_no in range(1, total_pages + 1):
        print(f"🔍 Đang xử lý trang {page_no}...")

        dfs_this_page = _tables_on_page(
            file_path, page_no, "lattice", "⚠️", strip_text="\n", line_scale=40
        )
        if not dfs_this_page:
            dfs_this_page = _tables_on_page(
                file_path, page_no, "stream", "❌", edge_tol=200, row_tol=10
            )

        if dfs_this_page:
            all_dfs.extend(dfs_this_page)
        else:
            print(f"🚫 Trang {page_no}: Không phát hiện bảng.")

    if not all_dfs:
        print("❌ Không tìm thấy bảng trong toàn bộ PDF.")
        return pd.DataFrame()

    df_final = _merge_page_tables(all_dfs)
    print(f"✅ Tổng hợp: {len(df_final)} dòng, {len(df_final.columns)} cột.")
    return df_final
|
|
|
|
|
def extract_pdf_note(file_path: str, max_lines: int = 12) -> str:
    """Return the trailing text of a PDF's last page (notes, remarks, ...).

    Uses pdfplumber to extract the text of the final page and keeps only its
    last ``max_lines`` lines.

    Args:
        file_path: Path of the PDF on disk.
        max_lines: How many trailing lines to keep (default 12). Values of
            zero or less yield an empty note.

    Returns:
        The selected lines joined with newlines. Extraction is best-effort:
        on any failure the error is printed and ``""`` is returned.
    """
    try:
        with pdfplumber.open(file_path) as pdf:
            last_page = pdf.pages[-1]
            text = (last_page.extract_text() or "").strip()

        lines = text.splitlines()
        # Guard the slice: lines[-0:] would return *every* line, not none.
        note_text = "\n".join(lines[-max_lines:]) if max_lines > 0 else ""
        print(f"📝 Extracted note text thành công.{note_text}")
        return note_text
    except Exception as e:
        print(f"⚠️ extract_pdf_note lỗi: {e}")
        return ""
|
|
def extract_airline_header_via_ocr(file_path: str) -> str:
    """Ask Gemini to read the airline header from the first page of a PDF.

    Only the first page is rendered (120 DPI) and it is re-encoded as a
    quality-60 JPEG to keep the upload small.

    Args:
        file_path: Path of the PDF on disk.

    Returns:
        The model's reply, stripped; ``""`` when the response carries no
        text. Errors from rendering, upload or generation propagate to the
        caller.
    """
    api_key = random.choice(DEFAULT_API_KEY)
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel("gemini-2.5-flash")

    # Render page 1 and release the document handle right away.
    with fitz.open(file_path) as pdf:
        png_bytes = pdf[0].get_pixmap(dpi=120).tobytes("png")
    img = Image.open(io.BytesIO(png_bytes)).convert("RGB")

    prompt = (
        "\n"
        "Identify from this airline rate sheet:\n"
        "- Airline name (e.g. Qatar Airways, Turkish Airlines)\n"
        "- Airline code (e.g. QR, TK, EK, VN)\n"
        "- Title (e.g. SGN PRICING NOV25)\n"
        "- Validity info (e.g. Effective from 01 Nov 2025, Until Further Notice)\n"
        "Return JSON with fields: airline_name, airline_code, title, valid_from, valid_to.\n"
    )

    # Write through the open handle rather than re-opening by name, which
    # fails on platforms that lock open temporary files.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        img_path = tmp.name
        img.save(tmp, format="JPEG", quality=60, optimize=True)

    uploaded = None
    try:
        uploaded = genai.upload_file(path=img_path, mime_type="image/jpeg")
        resp = model.generate_content([prompt, uploaded])
    finally:
        # Cleanup is best-effort so it can never mask the real error.
        if uploaded is not None:
            try:
                genai.delete_file(uploaded.name)
            except Exception as e:
                print(f"⚠️ Could not delete uploaded OCR image: {e}")
        try:
            os.remove(img_path)
        except OSError:
            pass

    # `resp.text` raises ValueError when the response has no usable part
    # (e.g. a blocked reply); a getattr default does not cover that.
    try:
        result = (resp.text or "").strip()
    except (AttributeError, ValueError):
        result = ""
    print("🛫 OCR header (compressed):", result)
    return result
|
|
def _compose_freight_prompt(base_prompt: str, header: str, content_text: str, note_text: str) -> str:
    """Join the instruction, optional header, CSV table and optional notes into one prompt."""
    sections = [base_prompt]

    if header and header.strip():
        sections.append(
            "\n### Header information (from first page OCR or PDF header):\n"
            f"{header}\n"
        )

    sections.append(
        "\n### Extracted table data (CSV format):\n"
        f"{content_text}\n"
    )

    if note_text and note_text.strip():
        sections.append(
            "\n### Notes or remarks extracted from the PDF:\n"
            f"{note_text}\n"
        )

    sections.append(
        "\nPlease analyze all data (header + table + notes) and generate the final JSON output\n"
        "following the defined schema above. Ensure that any airline, date, or rule from header/note\n"
        "is merged into the JSON result (e.g. shipping_line, valid_from, valid_to, remarks, etc.).\n"
    )
    return "\n".join(sections)


def call_gemini_with_prompt(
    header: str,
    content_text: str,
    note_text: str,
    question: str,
    model_choice: str,
    temperature: float,
    top_p: float
):
    """Send header + CSV table + notes to Gemini and return its reply.

    Args:
        header: Header text from the first page; skipped when blank.
        content_text: The extracted table, serialised as CSV.
        note_text: Trailing notes/remarks; skipped when blank.
        question: User-supplied prompt. When non-blank it replaces
            ``PROMPT_FREIGHT_JSON`` as the leading instruction.
        model_choice: UI label looked up in ``INTERNAL_MODEL_MAP``; unknown
            labels fall back to ``"gemini-2.5-flash"``.
        temperature: Sampling temperature passed to the model.
        top_p: Nucleus-sampling value passed to the model.

    Returns:
        The response text, or ``str(response)`` when the response exposes no
        text.
    """
    api_key = random.choice(DEFAULT_API_KEY)
    genai.configure(api_key=api_key)

    model = genai.GenerativeModel(
        model_name=INTERNAL_MODEL_MAP.get(model_choice, "gemini-2.5-flash"),
        generation_config={
            "temperature": float(temperature),
            "top_p": float(top_p)
        }
    )

    base_prompt = question.strip() if question and question.strip() else PROMPT_FREIGHT_JSON
    full_prompt = _compose_freight_prompt(base_prompt, header, content_text, note_text)

    print("🧠 Sending full prompt (with header if available) to Gemini...")
    response = model.generate_content(full_prompt)

    # `response.text` raises ValueError when no usable part came back, which
    # a getattr default never catches; fall back to the repr in that case.
    try:
        return response.text
    except (AttributeError, ValueError):
        return str(response)
|
|
|
|
|
|
|
|
|
|
|
def run_process(file, question, model_choice, temperature, top_p, external_api_url):
    """Gradio handler: turn an uploaded PDF/image into the model's JSON answer.

    PDFs go through table extraction first; when at least one table is found
    the table (CSV), the trailing notes and the OCR'd header are sent to
    Gemini. PDFs without tables, and all non-PDF files, are handed to
    ``run_process_internal_base_v2``.

    Args:
        file: The uploaded file reference from the UI, or ``None``.
        question: Optional custom prompt.
        model_choice: Model label selected in the UI.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling value.
        external_api_url: Accepted for interface compatibility; not used here.

    Returns:
        ``(message, None)``. Any exception is converted into an
        ``"ERROR: ..."`` message instead of being raised.
    """
    try:
        if file is None:
            return "❌ No file uploaded.", None

        file_bytes = _read_file_bytes(file)
        filename, mime = _guess_name_and_mime(file, file_bytes)
        print(f"[UPLOAD] {filename} ({mime})")

        if mime == "application/pdf":
            # The extractors need a real path, so spill the bytes to disk.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                tmp.write(file_bytes)
                tmp_path = tmp.name

            try:
                df = extract_pdf_tables(tmp_path)
                if not df.empty:
                    # Notes and the header OCR (a model call) are only
                    # needed on this path, so compute them only here.
                    note_text = extract_pdf_note(tmp_path)
                    header = extract_airline_header_via_ocr(tmp_path)
                    csv_text = df.to_csv(index=False)
                    print("✅ Gửi Gemini để sinh JSON...")
                    message = call_gemini_with_prompt(
                        header, csv_text, note_text, question, model_choice, temperature, top_p
                    )
                    return message, None
            finally:
                # Always remove the spilled copy, success or failure.
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

            print("⚠️ Không có bảng hợp lệ, fallback OCR Gemini.")

        return run_process_internal_base_v2(
            file_bytes, filename, mime, question, model_choice, temperature, top_p
        )

    except Exception as e:
        return f"ERROR: {type(e).__name__}: {e}", None
|
|
def _render_pdf_pages(file_bytes: bytes, dpi: int = 150) -> list:
    """Rasterise every page of an in-memory PDF into an RGB PIL image."""
    with fitz.open(stream=file_bytes, filetype="pdf") as doc:
        return [
            Image.open(io.BytesIO(page.get_pixmap(dpi=dpi).tobytes("png"))).convert("RGB")
            for page in doc
        ]


def _save_temp_png(image) -> str:
    """Write ``image`` to a temporary PNG file and return its path (caller deletes it)."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        # Save through the open handle; re-opening by name fails on
        # platforms that lock open temporary files.
        image.save(tmp, format="PNG")
        return tmp.name


def run_process_internal_base_v2(file_bytes, filename, mime, question, model_choice, temperature, top_p, batch_size=3):
    """Image-based fallback: send page images to Gemini in batches.

    A PDF is rendered page by page; any other input is opened as a single
    image. Pages are uploaded ``batch_size`` at a time together with the
    prompt, and the textual replies of all batches are concatenated.

    Args:
        file_bytes: Raw content of the uploaded file.
        filename: Name of the upload (not used here).
        mime: MIME type of the upload (not used here; the PDF signature in
            ``file_bytes`` decides the path).
        question: Optional custom prompt; blank means ``PROMPT_FREIGHT_JSON``.
        model_choice: UI label looked up in ``INTERNAL_MODEL_MAP``; unknown
            labels fall back to ``"gemini-2.5-flash"``.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling value.
        batch_size: Pages per request; values below 1 are treated as 1.

    Returns:
        ``(text, None)`` where ``text`` joins the per-batch replies with a
        blank line, or an ``"ERROR: ..."`` message when no API key is
        configured.
    """
    # Check the pool before choosing: random.choice raises on an empty list.
    api_key = random.choice(DEFAULT_API_KEY) if DEFAULT_API_KEY else None
    if not api_key:
        return "ERROR: Missing GOOGLE_API_KEY.", None

    genai.configure(api_key=api_key)
    model_name = INTERNAL_MODEL_MAP.get(model_choice, "gemini-2.5-flash")
    model = genai.GenerativeModel(
        model_name=model_name,
        generation_config={"temperature": float(temperature), "top_p": float(top_p)},
    )

    if file_bytes[:4] == b"%PDF":
        # The file previously called `pdf_to_images`, which it neither
        # defines nor imports; render the pages with PyMuPDF instead.
        pages = _render_pdf_pages(file_bytes)
    else:
        pages = [Image.open(io.BytesIO(file_bytes))]

    user_prompt = (question or "").strip() or PROMPT_FREIGHT_JSON
    batch_size = max(1, int(batch_size))
    all_text_results = []

    for start in range(0, len(pages), batch_size):
        uploaded = []
        try:
            for im in pages[start:start + batch_size]:
                png_path = _save_temp_png(im)
                try:
                    handle = genai.upload_file(path=png_path, mime_type="image/png")
                finally:
                    try:
                        os.remove(png_path)
                    except OSError:
                        pass
                # Register before refreshing so cleanup sees it even if the
                # refresh fails.
                uploaded.append(handle)
                uploaded[-1] = genai.get_file(handle.name)

            resp = model.generate_content([user_prompt] + uploaded)
            # `resp.text` raises when the reply has no usable part; treat
            # that batch as empty rather than aborting the whole run.
            try:
                text = resp.text or ""
            except Exception:
                text = ""
            all_text_results.append(text)
        finally:
            # Remote files are removed whether or not generation succeeded.
            for up in uploaded:
                try:
                    genai.delete_file(up.name)
                except Exception as e:
                    print(f"⚠️ Could not delete uploaded page image: {e}")

    return "\n\n".join(all_text_results), None
|
|
|
|
|
def main():
    """Assemble the Gradio Blocks interface and return it without launching.

    The layout holds a file upload, a prompt box, a model dropdown, two
    sampling sliders, a hidden external-URL box, a JSON code view for the
    result and a button wired to ``run_process``.
    """
    model_labels = list(INTERNAL_MODEL_MAP.keys()) + [EXTERNAL_MODEL_NAME]

    with gr.Blocks(title="OCR Multi-Agent System") as app:
        upload_box = gr.File(label="Upload PDF/Image")
        prompt_box = gr.Textbox(label="Prompt", lines=2)
        model_picker = gr.Dropdown(
            choices=model_labels,
            value="Gemini 2.5 Flash",
            label="Model",
        )
        temperature_slider = gr.Slider(0.0, 2.0, value=0.2, step=0.05)
        top_p_slider = gr.Slider(0.0, 1.0, value=0.95, step=0.01)
        api_url_box = gr.Textbox(label="External API URL", visible=False)
        result_view = gr.Code(label="Output", language="json")
        process_button = gr.Button("🚀 Process")

        # run_process returns two values; the second lands in a throwaway State.
        handler_inputs = [
            upload_box,
            prompt_box,
            model_picker,
            temperature_slider,
            top_p_slider,
            api_url_box,
        ]
        process_button.click(
            run_process,
            inputs=handler_inputs,
            outputs=[result_view, gr.State()],
        )

    return app
|
|
|
|
|
|
|
|
# The interface is built at import time and exposed as the module-level `demo`.
demo = main()


if __name__ == "__main__":
    # Executed as a script: start the Gradio server.
    demo.launch()
|
|
|