File size: 6,398 Bytes
d49f9c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
from fastapi import APIRouter, File, UploadFile, Form, HTTPException, status
from fastapi.responses import JSONResponse
from config import settings
from PIL import Image
import urllib.request
from io import BytesIO
import utils
import os
import time
from functools import lru_cache
from paddleocr import PaddleOCR
from pdf2image import convert_from_bytes
import io
import json
from routers.data_utils import merge_data
from routers.data_utils import store_data
import motor.motor_asyncio
from typing import Optional
from pymongo import ASCENDING
from pymongo.errors import DuplicateKeyError


router = APIRouter()

# MongoDB connection state shared by this module's handlers. Both names stay
# None unless startup_event finds MONGODB_URL in the environment and connects.
client = db = None


async def create_unique_index(collection, *fields):
    """Create a unique (compound) index over ``fields`` on ``collection``.

    Every field is indexed with direction ``1`` (ascending). The awaited
    ``create_index`` result is returned unchanged so callers can log it.
    """
    directions = [1] * len(fields)
    spec = list(zip(fields, directions))
    result = await collection.create_index(spec, unique=True)
    return result


async def create_ttl_index(db, collection_name, field, expire_after_seconds):
    """Ensure an ascending TTL index on ``field`` exists in ``db[collection_name]``.

    MongoDB's TTL monitor removes documents whose ``field`` value (expected to
    be a date — confirm against the writer of this collection) is older than
    ``expire_after_seconds``.

    Returns:
        The result of ``create_index`` (the index name). This matches
        ``create_unique_index``, so callers can log or inspect it. The previous
        version returned None, so existing callers that ignore the value are
        unaffected.
    """
    index_result = await db[collection_name].create_index(
        [(field, ASCENDING)], expireAfterSeconds=expire_after_seconds
    )
    print(f"TTL index created or already exists: {index_result}")
    return index_result


@router.on_event("startup")
async def startup_event():
    """Open the MongoDB connection and ensure the required indexes exist.

    Does nothing unless MONGODB_URL is present in the environment. In that
    case the module-level ``client`` and ``db`` (database ``chatgpt_plugin``)
    are populated before the indexes are created.
    """
    global client, db

    if "MONGODB_URL" not in os.environ:
        return

    client = motor.motor_asyncio.AsyncIOMotorClient(os.environ["MONGODB_URL"])
    db = client.chatgpt_plugin

    # (collection, key fields) combinations that must be unique.
    unique_specs = (
        ("uploads", ("receipt_key",)),
        ("receipts", ("user", "receipt_key")),
    )
    for collection_name, key_fields in unique_specs:
        index_result = await create_unique_index(db[collection_name], *key_fields)
        print(f"Unique index created or already exists: {index_result}")

    # TTL index on uploads.created_at with a 15-minute (900 s) expiry.
    await create_ttl_index(db, 'uploads', 'created_at', 15 * 60)

    print("Connected to MongoDB from OCR!")


@router.on_event("shutdown")
async def shutdown_event():
    """Close the MongoDB client opened by ``startup_event``, if there is one."""
    # Check the client itself rather than re-reading MONGODB_URL. If startup
    # failed before the client was created, ``client`` is still None, and
    # calling close() on it would raise AttributeError during shutdown.
    if client is not None:
        client.close()


@lru_cache(maxsize=1)
def load_ocr_model():
    """Return the PaddleOCR instance, constructing it on first use.

    The zero-argument call is memoized by ``lru_cache(maxsize=1)``. The model
    is therefore built once per process and reused by later OCR requests.
    """
    # English recognition, with the text-angle classifier enabled.
    return PaddleOCR(lang='en', use_angle_cls=True)


def _flatten_ocr_result(result):
    """Concatenate PaddleOCR's per-page line lists into one flat list.

    Pages reported as ``None`` are skipped instead of raising TypeError. Some
    PaddleOCR versions return ``None`` for a page on which no text was
    detected.
    """
    if not result:
        return []
    return [line for page in result if page for line in page]


def invoke_ocr(doc, content_type):
    """Run PaddleOCR on a PIL image and merge the recognised lines.

    The image is re-encoded in memory before recognition: as PNG when
    ``content_type`` is ``image/png``, and as JPEG for everything else
    (including PDF pages rendered to images).

    Args:
        doc: PIL image to recognise.
        content_type: MIME type of the original input. It is only used to
            pick the in-memory encoding format.

    Returns:
        ``(values, processing_time)``: the output of ``merge_data`` over the
        flattened OCR lines, and the elapsed wall-clock seconds for this call.
    """
    worker_pid = os.getpid()
    print(f"Handling OCR request with worker PID: {worker_pid}")
    start_time = time.time()

    model = load_ocr_model()

    format_img = "PNG" if content_type == "image/png" else "JPEG"
    # The context manager releases the buffer even if doc.save() raises.
    with io.BytesIO() as bytes_img:
        doc.save(bytes_img, format=format_img)
        bytes_data = bytes_img.getvalue()

    result = model.ocr(bytes_data, cls=True)
    values = merge_data(_flatten_ocr_result(result))

    processing_time = time.time() - start_time
    print(f"OCR done, worker PID: {worker_pid}")

    return values, processing_time

_IMAGE_CONTENT_TYPES = ("image/jpeg", "image/jpg", "image/png")
# PDFs fetched by URL are accepted under their proper MIME type. As before,
# they are also accepted as the generic application/octet-stream that some
# hosts serve them with.
_URL_PDF_CONTENT_TYPES = ("application/pdf", "application/octet-stream")
# Per-operation socket timeout for URL downloads. urlopen() is blocking and is
# called from an async handler, so a stalled host would otherwise freeze the
# event loop indefinitely.
_URL_FETCH_TIMEOUT_SECONDS = 30
_INVALID_TYPE_MESSAGE = "Invalid file type. Only JPG/PNG images and PDF are allowed."


async def _ocr_and_store(doc, content_type, file_name, post_processing):
    """Run OCR on ``doc``, log timing stats, and optionally persist the result.

    Returns:
        The merged OCR values. When ``post_processing`` is requested and
        MONGODB_URL is configured, it returns whatever ``store_data`` returns
        instead (logged as the stored key).

    Raises:
        HTTPException: 400 when ``store_data`` raises DuplicateKeyError.
    """
    result, processing_time = invoke_ocr(doc, content_type)

    utils.log_stats(settings.ocr_stats_file, [processing_time, file_name])
    print(f"Processing time OCR: {processing_time:.2f} seconds")

    if post_processing and "MONGODB_URL" in os.environ:
        print("Postprocessing...")
        try:
            result = await store_data(result, db)
        except DuplicateKeyError as err:
            # The exception must be raised, not returned. Returning the
            # exception object does not produce an HTTP 400 response.
            raise HTTPException(status_code=400, detail="Duplicate data.") from err
        print(f"Stored data with key: {result}")

    return result


@router.post("/ocr")
async def run_ocr(file: Optional[UploadFile] = File(None), image_url: Optional[str] = Form(None),
                  post_processing: Optional[bool] = Form(False), sparrow_key: str = Form(None)):
    """OCR an uploaded file or a remote image/PDF URL.

    Only one input is processed: ``file`` takes precedence over ``image_url``.
    For PDFs, only the first page, rendered at 300 DPI, is used.

    Args:
        file: Uploaded JPG/PNG image or PDF.
        image_url: http(s) URL of a JPG/PNG image or PDF. It is used when no
            file is sent.
        post_processing: When true and MONGODB_URL is set, the OCR result is
            stored via ``store_data`` and that call's return value is sent instead.
        sparrow_key: Shared secret that must equal ``settings.sparrow_key``.

    Returns:
        A JSON response with the OCR result, or ``{"info": "No input provided"}``
        when neither input is given. An invalid key, an unsupported content
        type, or a non-http(s) URL yields an ``{"error": ...}`` payload.

    Raises:
        HTTPException: 400 on duplicate stored data or when no result was produced.
    """
    if sparrow_key != settings.sparrow_key:
        return {"error": "Invalid Sparrow key."}

    if file:
        if file.content_type in _IMAGE_CONTENT_TYPES:
            doc = Image.open(BytesIO(await file.read()))
        elif file.content_type == "application/pdf":
            pages = convert_from_bytes(await file.read(), 300)
            doc = pages[0]
        else:
            return {"error": _INVALID_TYPE_MESSAGE}

        result = await _ocr_and_store(doc, file.content_type, file.filename, post_processing)
    elif image_url:
        # test image url: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg
        # test PDF: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/receipts/2021/us/bestbuy-20211211_006.pdf
        # urlopen() also handles file:, ftp: and data: URLs. Allow only remote
        # http(s) fetches, so clients cannot make the server read local files.
        if not image_url.strip().lower().startswith(("http://", "https://")):
            return {"error": "Only http(s) URLs are allowed."}

        with urllib.request.urlopen(image_url, timeout=_URL_FETCH_TIMEOUT_SECONDS) as response:
            content_type = response.info().get_content_type()

            if content_type in _IMAGE_CONTENT_TYPES:
                doc = Image.open(BytesIO(response.read()))
            elif content_type in _URL_PDF_CONTENT_TYPES:
                pages = convert_from_bytes(response.read(), 300)
                doc = pages[0]
            else:
                return {"error": _INVALID_TYPE_MESSAGE}

        # Log the last path segment of the URL as the file name.
        file_name = image_url.split("/")[-1]
        result = await _ocr_and_store(doc, content_type, file_name, post_processing)
    else:
        result = {"info": "No input provided"}

    if result is None:
        raise HTTPException(status_code=400, detail="Failed to process the input.")

    return JSONResponse(status_code=status.HTTP_200_OK, content=result)


@router.get("/statistics")
async def get_statistics():
    """Return the OCR processing statistics stored in ``settings.ocr_stats_file``.

    Returns:
        The JSON-decoded file content. An empty list is returned when the file
        does not exist yet or does not contain valid JSON.
    """
    file_path = settings.ocr_stats_file

    # EAFP: opening directly avoids the race where the file disappears between
    # an existence check and the open() call.
    try:
        with open(file_path, 'r') as file:
            return json.load(file)
    except (FileNotFoundError, json.JSONDecodeError):
        return []