Upload 14 files
Browse files- .gitattributes +1 -35
- LICENSE +21 -0
- README.md +49 -3
- app.py +63 -0
- config.yaml +10 -0
- entity_tagger.py +12 -0
- examples/demo_commands.txt +15 -0
- examples/invoice_sample.pdf +6 -0
- models/layoutlm_processor.py +8 -0
- ocr_extractor.py +43 -0
- pdf_loader.py +10 -0
- requirements.txt +12 -0
- summarize_doc.py +36 -0
- utils.py +15 -0
.gitattributes
CHANGED
|
@@ -1,35 +1 @@
|
|
| 1 |
-
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
* text=auto
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 hmnshudhmn24
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,3 +1,49 @@
|
|
| 1 |
-
---
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- en
|
| 4 |
+
license: mit
|
| 5 |
+
tags:
|
| 6 |
+
- document-question-answering
|
| 7 |
+
- ocr
|
| 8 |
+
- summarization
|
| 9 |
+
- document-ai
|
| 10 |
+
pipeline_tag: document-question-answering
|
| 11 |
+
model_name: docintel
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# 🧾 DOCINTEL — Document AI (Donut-based)
|
| 15 |
+
|
| 16 |
+
**DOCINTEL** extracts structured insights from scanned PDFs and images using **naver-clova-ix/donut-base** (Donut). It supports OCR fallback, entity extraction, and document summarization via Donut on page images.
|
| 17 |
+
|
| 18 |
+
> ⚠️ Install system dependencies: `poppler` and `tesseract` for pdf2image and pytesseract respectively.
|
| 19 |
+
|
| 20 |
+
## Quickstart
|
| 21 |
+
|
| 22 |
+
1. Create venv & install dependencies:
|
| 23 |
+
```bash
|
| 24 |
+
python -m venv venv
|
| 25 |
+
source venv/bin/activate # Windows: venv\Scripts\activate
|
| 26 |
+
pip install -r requirements.txt
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
2. Run API server:
|
| 30 |
+
```bash
|
| 31 |
+
uvicorn app:app --host 0.0.0.0 --port 8000
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
3. Upload a PDF and call endpoints (see examples/demo_commands.txt).
|
| 35 |
+
|
| 36 |
+
## Files
|
| 37 |
+
- `ocr_extractor.py` — PDF→images→OCR pipeline
|
| 38 |
+
- `pdf_loader.py` — extract embedded text from PDFs
|
| 39 |
+
- `entity_tagger.py` — regex-based entity extraction
|
| 40 |
+
- `summarize_doc.py` — DONUT-based summarizer for page images
|
| 41 |
+
- `app.py` — FastAPI server with upload/summary endpoints
|
| 42 |
+
|
| 43 |
+
## Notes
|
| 44 |
+
- Donut requires vision-encoder-decoder inference which may need GPU for speed.
|
| 45 |
+
- For text-only PDFs consider using `extract_text_from_pdf` then a text summarizer instead of Donut.
|
| 46 |
+
- This repo is a prototype/demo. Validate on your data before production use.
|
| 47 |
+
|
| 48 |
+
## License
|
| 49 |
+
MIT
|
app.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI app for DOCINTEL: upload PDF, extract text/OCR, get entities, summarize."""
|
| 2 |
+
import os, uuid, tempfile
|
| 3 |
+
from fastapi import FastAPI, File, UploadFile, HTTPException
|
| 4 |
+
from fastapi.responses import JSONResponse
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
from utils import ensure_dir, load_config, save_json
|
| 7 |
+
from ocr_extractor import extract_full_text, pdf_to_images
|
| 8 |
+
from entity_tagger import extract_entities
|
| 9 |
+
from summarize_doc import summarize_image, summarize_text
|
| 10 |
+
|
| 11 |
+
app = FastAPI(title='DOCINTEL API')
|
| 12 |
+
|
| 13 |
+
cfg = load_config()
|
| 14 |
+
STORAGE = cfg.get('storage_dir', './storage')
|
| 15 |
+
ensure_dir(STORAGE)
|
| 16 |
+
|
| 17 |
+
class QARequest(BaseModel):
|
| 18 |
+
question: str
|
| 19 |
+
|
| 20 |
+
@app.post('/upload_pdf')
|
| 21 |
+
async def upload_pdf(file: UploadFile = File(...)):
|
| 22 |
+
if not file.filename.lower().endswith('.pdf'):
|
| 23 |
+
raise HTTPException(status_code=400, detail='Only PDF files are allowed')
|
| 24 |
+
doc_id = str(uuid.uuid4())
|
| 25 |
+
save_path = os.path.join(STORAGE, f"{doc_id}_{file.filename}")
|
| 26 |
+
with open(save_path, 'wb') as f:
|
| 27 |
+
f.write(await file.read())
|
| 28 |
+
return {'doc_id': doc_id, 'filename': file.filename, 'path': save_path}
|
| 29 |
+
|
| 30 |
+
@app.get('/doc/{doc_id}/text')
|
| 31 |
+
def get_text(doc_id: str):
|
| 32 |
+
files = [f for f in os.listdir(STORAGE) if f.startswith(doc_id+'_')]
|
| 33 |
+
if not files:
|
| 34 |
+
raise HTTPException(status_code=404, detail='Document not found')
|
| 35 |
+
path = os.path.join(STORAGE, files[0])
|
| 36 |
+
text, ocr_pages = extract_full_text(path)
|
| 37 |
+
return {'doc_id': doc_id, 'text': text, 'ocr_pages_count': len(ocr_pages)}
|
| 38 |
+
|
| 39 |
+
@app.get('/doc/{doc_id}/entities')
|
| 40 |
+
def get_entities(doc_id: str):
|
| 41 |
+
files = [f for f in os.listdir(STORAGE) if f.startswith(doc_id+'_')]
|
| 42 |
+
if not files:
|
| 43 |
+
raise HTTPException(status_code=404, detail='Document not found')
|
| 44 |
+
path = os.path.join(STORAGE, files[0])
|
| 45 |
+
text, _ = extract_full_text(path)
|
| 46 |
+
ents = extract_entities(text)
|
| 47 |
+
return JSONResponse(content={'doc_id': doc_id, 'entities': ents})
|
| 48 |
+
|
| 49 |
+
@app.post('/doc/{doc_id}/summarize')
|
| 50 |
+
def post_summarize(doc_id: str):
|
| 51 |
+
files = [f for f in os.listdir(STORAGE) if f.startswith(doc_id+'_')]
|
| 52 |
+
if not files:
|
| 53 |
+
raise HTTPException(status_code=404, detail='Document not found')
|
| 54 |
+
path = os.path.join(STORAGE, files[0])
|
| 55 |
+
# convert to images and summarize first page with DONUT
|
| 56 |
+
pages = pdf_to_images(path, out_dir=tempfile.mkdtemp())
|
| 57 |
+
if not pages:
|
| 58 |
+
text, _ = extract_full_text(path)
|
| 59 |
+
summary = summarize_text(text)
|
| 60 |
+
return {'doc_id': doc_id, 'summary': summary}
|
| 61 |
+
# use first page image
|
| 62 |
+
summary = summarize_image(pages[0])
|
| 63 |
+
return {'doc_id': doc_id, 'summary': summary}
|
config.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
name: "naver-clova-ix/donut-base"
|
| 3 |
+
task: "document-question-answering"
|
| 4 |
+
ocr:
|
| 5 |
+
lang: "eng"
|
| 6 |
+
dpi: 300
|
| 7 |
+
server:
|
| 8 |
+
host: "0.0.0.0"
|
| 9 |
+
port: 8000
|
| 10 |
+
storage_dir: "./storage"
|
entity_tagger.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Simple regex-based entity extraction for demo purposes."""
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
def extract_entities(text):
|
| 5 |
+
entities = {}
|
| 6 |
+
emails = re.findall(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b", text)
|
| 7 |
+
dates = re.findall(r"\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b", text)
|
| 8 |
+
amounts = re.findall(r"\b\$?\d{1,3}(?:[.,]\d{3})*(?:[.,]\d+)?\s?(?:USD|INR|EUR|Rs|\$)?\b", text)
|
| 9 |
+
entities['emails'] = list(dict.fromkeys(emails))
|
| 10 |
+
entities['dates'] = list(dict.fromkeys(dates))
|
| 11 |
+
entities['amounts'] = list(dict.fromkeys([a.strip() for a in amounts if a.strip()]))
|
| 12 |
+
return entities
|
examples/demo_commands.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Example commands for DOCINTEL
|
| 2 |
+
# 1) Start API server
|
| 3 |
+
uvicorn app:app --host 0.0.0.0 --port 8000
|
| 4 |
+
|
| 5 |
+
# 2) Upload PDF
|
| 6 |
+
curl -X POST "http://127.0.0.1:8000/upload_pdf" -F "file=@examples/invoice_sample.pdf"
|
| 7 |
+
|
| 8 |
+
# 3) Get extracted text
|
| 9 |
+
curl "http://127.0.0.1:8000/doc/<DOC_ID>/text"
|
| 10 |
+
|
| 11 |
+
# 4) Get entities
|
| 12 |
+
curl "http://127.0.0.1:8000/doc/<DOC_ID>/entities"
|
| 13 |
+
|
| 14 |
+
# 5) Summarize document
|
| 15 |
+
curl -X POST "http://127.0.0.1:8000/doc/<DOC_ID>/summarize"
|
examples/invoice_sample.pdf
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
%PDF-1.4
|
| 2 |
+
%\xe2\xe3\xcf\xd3
|
| 3 |
+
1 0 obj<<>>endobj
|
| 4 |
+
trailer
|
| 5 |
+
<<>>
|
| 6 |
+
%%EOF
|
models/layoutlm_processor.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LayoutLM helper (optional) - provided for completeness but not used by default.
|
| 2 |
+
"""
|
| 3 |
+
from transformers import LayoutLMv3Processor, LayoutLMv3ForQuestionAnswering
|
| 4 |
+
|
| 5 |
+
def load_layoutlm(model_name='microsoft/layoutlmv3-base'):
|
| 6 |
+
proc = LayoutLMv3Processor.from_pretrained(model_name)
|
| 7 |
+
model = LayoutLMv3ForQuestionAnswering.from_pretrained(model_name)
|
| 8 |
+
return proc, model
|
ocr_extractor.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OCR extraction using pdf2image + pytesseract for scanned pages."""
|
| 2 |
+
from pdf2image import convert_from_path
|
| 3 |
+
import pytesseract
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from utils import load_config
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
def pdf_to_images(pdf_path, dpi=None, out_dir=None):
|
| 9 |
+
cfg = load_config()
|
| 10 |
+
dpi = dpi or cfg.get('ocr', {}).get('dpi', 300)
|
| 11 |
+
pages = convert_from_path(pdf_path, dpi=dpi)
|
| 12 |
+
paths = []
|
| 13 |
+
if out_dir:
|
| 14 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 15 |
+
for i, img in enumerate(pages, start=1):
|
| 16 |
+
path = os.path.join(out_dir or '.', f'page_{i}.png')
|
| 17 |
+
img.save(path, 'PNG')
|
| 18 |
+
paths.append(path)
|
| 19 |
+
return paths
|
| 20 |
+
|
| 21 |
+
def ocr_image(path, lang=None):
|
| 22 |
+
cfg = load_config()
|
| 23 |
+
lang = lang or cfg.get('ocr', {}).get('lang', 'eng')
|
| 24 |
+
img = Image.open(path)
|
| 25 |
+
text = pytesseract.image_to_string(img, lang=lang)
|
| 26 |
+
return text
|
| 27 |
+
|
| 28 |
+
def extract_full_text(pdf_path, do_ocr=True):
|
| 29 |
+
# Try embedded text first
|
| 30 |
+
try:
|
| 31 |
+
from pdf_loader import extract_text_from_pdf
|
| 32 |
+
txt = extract_text_from_pdf(pdf_path)
|
| 33 |
+
if txt and len(txt) > 200:
|
| 34 |
+
return txt, [] # return text and empty ocr pages list
|
| 35 |
+
except Exception:
|
| 36 |
+
txt = ''
|
| 37 |
+
# fallback to OCR
|
| 38 |
+
pages = pdf_to_images(pdf_path, out_dir='./temp_pages')
|
| 39 |
+
ocr_texts = []
|
| 40 |
+
for p in pages:
|
| 41 |
+
ocr_texts.append(ocr_image(p))
|
| 42 |
+
full = '\n\n'.join(ocr_texts)
|
| 43 |
+
return full, ocr_texts
|
pdf_loader.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PDF text extraction using PyMuPDF (fitz) for embedded text layers."""
|
| 2 |
+
import fitz
|
| 3 |
+
|
| 4 |
+
def extract_text_from_pdf(pdf_path):
|
| 5 |
+
doc = fitz.open(pdf_path)
|
| 6 |
+
texts = []
|
| 7 |
+
for page in doc:
|
| 8 |
+
txt = page.get_text('text') or ''
|
| 9 |
+
texts.append(txt)
|
| 10 |
+
return '\n\n'.join(texts)
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers>=4.30.0
|
| 2 |
+
torch>=1.12.0
|
| 3 |
+
pdf2image
|
| 4 |
+
pytesseract
|
| 5 |
+
Pillow
|
| 6 |
+
PyMuPDF
|
| 7 |
+
fastapi
|
| 8 |
+
uvicorn[standard]
|
| 9 |
+
python-multipart
|
| 10 |
+
pyyaml
|
| 11 |
+
requests
|
| 12 |
+
tqdm
|
summarize_doc.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Summarization using DONUT (naver-clova-ix/donut-base) via Hugging Face.
|
| 2 |
+
This module uses Donut's processor and VisionEncoderDecoderModel for docVQA-style prompts.
|
| 3 |
+
"""
|
| 4 |
+
from transformers import DonutProcessor, VisionEncoderDecoderModel
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from utils import load_config
|
| 7 |
+
|
| 8 |
+
_processor = None
|
| 9 |
+
_model = None
|
| 10 |
+
|
| 11 |
+
def _init(model_name):
|
| 12 |
+
global _processor, _model
|
| 13 |
+
if _processor is None or _model is None:
|
| 14 |
+
_processor = DonutProcessor.from_pretrained(model_name)
|
| 15 |
+
_model = VisionEncoderDecoderModel.from_pretrained(model_name)
|
| 16 |
+
return _processor, _model
|
| 17 |
+
|
| 18 |
+
def summarize_image(image_path, model_name=None, max_length=250):
|
| 19 |
+
cfg = load_config()
|
| 20 |
+
model_name = model_name or cfg.get('model', {}).get('name')
|
| 21 |
+
processor, model = _init(model_name)
|
| 22 |
+
image = Image.open(image_path).convert('RGB')
|
| 23 |
+
task_prompt = '<s_docvqa><s_question>Summarize the document:</s_question>'
|
| 24 |
+
inputs = processor(image, task_prompt, return_tensors='pt')
|
| 25 |
+
output = model.generate(**inputs, max_new_tokens=max_length)
|
| 26 |
+
decoded = processor.batch_decode(output, skip_special_tokens=True)[0]
|
| 27 |
+
return decoded
|
| 28 |
+
|
| 29 |
+
def summarize_text(text, chunk_size=1000, model_name=None):
|
| 30 |
+
# naive: summarize by extracting first chunk and running model on placeholder image (not ideal for text-only)
|
| 31 |
+
# For text-heavy docs, use text summarization pipeline instead; here we return a simple extractive summary.
|
| 32 |
+
lines = [l.strip() for l in text.split('\n') if l.strip()]
|
| 33 |
+
if not lines:
|
| 34 |
+
return ''
|
| 35 |
+
summary = ' '.join(lines[:min(5, len(lines))])
|
| 36 |
+
return summary
|
utils.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import os, yaml, json
|
| 3 |
+
|
| 4 |
+
def load_config(path='config.yaml'):
|
| 5 |
+
p = Path(path)
|
| 6 |
+
if not p.exists():
|
| 7 |
+
raise FileNotFoundError(f'Config not found: {path}')
|
| 8 |
+
return yaml.safe_load(p.read_text())
|
| 9 |
+
|
| 10 |
+
def ensure_dir(path):
|
| 11 |
+
os.makedirs(path, exist_ok=True)
|
| 12 |
+
|
| 13 |
+
def save_json(obj, path):
|
| 14 |
+
ensure_dir(Path(path).parent)
|
| 15 |
+
Path(path).write_text(json.dumps(obj, indent=2), encoding='utf-8')
|