File size: 2,973 Bytes
26ab692
 
 
 
 
 
 
a95b973
26ab692
a95b973
26ab692
 
a95b973
 
 
 
 
 
26ab692
 
a95b973
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26ab692
 
a95b973
26ab692
 
a95b973
26ab692
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a95b973
26ab692
 
 
 
 
 
 
 
 
 
 
 
 
a95b973
26ab692
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import logging 

# Root logging configuration for the demo: DEBUG level, timestamped records
# tagged with level and logger name.
# NOTE(review): this call sits before the remaining imports below, presumably
# so the root handler exists before third-party modules configure their own
# logging; confirm the ordering is intentional. DEBUG on the root logger may
# also surface verbose debug output from libraries.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
import streamlit as st
import streamlit.components.v1 as components
from myocr.pipelines import CommonOCRPipeline
from PIL import Image, ImageDraw, ImageFont
import tempfile
import pandas as pd
import os

# Module logger; used by the top-level error handler.
logger = logging.getLogger(__name__)

# Wide layout so the image column and the results column fit side by side.
st.set_page_config(layout="wide")
st.title("MyOCR Demo")

# Build the OCR pipeline once and share it.
@st.cache_resource
def load_pipeline():
    """Return the shared ``CommonOCRPipeline`` instance.

    ``st.cache_resource`` memoizes the result, so the pipeline is constructed
    on the first call only and the same object is returned on every later
    script rerun instead of being rebuilt.
    """
    # NOTE(review): "cpu" presumably selects the inference device; confirm.
    return CommonOCRPipeline("cpu")

# Streamlit re-executes this script body on every rerun. load_pipeline() is
# cached, so it hands back the same pipeline object each time.
pipeline = load_pipeline()
# Font for the labels drawn above detected boxes (size 12), reloaded on each
# rerun. The path is relative, so it resolves against the process's current
# working directory: the app must be launched from the directory that
# contains ``src/``.
# NOTE(review): presumably a Noto font is bundled so non-Latin recognized text
# renders correctly; confirm.
font = ImageFont.truetype("src/NotoSans.ttf", 12) 

def google_analytics():
    """Inject the Google Analytics (gtag.js) snippet into the page.

    The measurement ID is read from the ``GOOGLE_ANALYTICS`` environment
    variable. When it is unset or blank, nothing is injected. Loading gtag.js
    with a placeholder ID would only cause a useless third-party request on
    every page view.

    Returns:
        None.
    """
    measurement_id = (os.getenv("GOOGLE_ANALYTICS") or "").strip()
    if not measurement_id:
        return
    google_analytics_code = f"""
    <!-- Google tag (gtag.js) -->
    <script async src="https://www.googletagmanager.com/gtag/js?id={measurement_id}"></script>
    <script>
      window.dataLayer = window.dataLayer || [];
      function gtag(){{dataLayer.push(arguments);}}
      gtag('js', new Date());
      gtag('config', '{measurement_id}');
    </script>
    """
    # Zero-size component: the snippet has no visible output.
    # NOTE(review): components.html renders inside an iframe, so the tag runs
    # in the iframe's context rather than the parent page; confirm the
    # tracked page location is what is intended.
    components.html(google_analytics_code, height=0, width=0)

def main():
    """Render the demo page: upload an image, run OCR, show boxes and a table.

    The left column holds the uploader and an image preview, which is replaced
    by an annotated copy (red boxes, green labels) once recognition finishes.
    The right column lists each recognized text with its confidence. While no
    file has been uploaded, the function returns early and nothing further is
    rendered.
    """
    google_analytics()
    left_col, right_col = st.columns([2, 1.5])
    with left_col:
        # A non-empty label avoids Streamlit's empty-label accessibility
        # warning; it is collapsed so the page looks the same as before.
        uploaded_file = st.file_uploader(
            "Upload an image",
            type=["png", "jpg", "jpeg"],
            label_visibility="collapsed",
        )
        spinner_container = st.empty()
        image_slot = st.empty()
        if not uploaded_file:
            return

        image = Image.open(uploaded_file)
        image_slot.image(image, use_container_width=True)

        # The pipeline takes a file path, so persist the upload first. Only the
        # extension of the client-supplied name is kept. The file is closed
        # (and therefore flushed) before the pipeline opens it by name, so the
        # pipeline never sees a partially written file. Unlike reopening an
        # open NamedTemporaryFile, this also works on Windows.
        suffix = os.path.splitext(uploaded_file.name)[1]
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, "upload" + suffix)
            with open(temp_file_path, "wb") as temp_file:
                temp_file.write(uploaded_file.getbuffer())
            with spinner_container:
                with st.spinner("Recognizing text..."):
                    results = pipeline(temp_file_path)

        # NOTE(review): assumes the pipeline may return None when no text is
        # detected; confirm against CommonOCRPipeline.
        text_items = results.text_items if results is not None else []

        # convert() returns a new image, so the preview above is untouched.
        image_with_boxes = image.convert("RGB")
        draw = ImageDraw.Draw(image_with_boxes)
        table_data = []
        for item in text_items:
            bbox = item.bounding_box
            draw.rectangle(
                [(bbox.left, bbox.top), (bbox.right, bbox.bottom)],
                outline="red",
                width=2,
            )
            # Place the label 18px above the box, clamped to the image top.
            draw.text(
                (bbox.left, max(bbox.top - 18, 0)),
                item.text,
                font=font,
                fill="green",
            )
            table_data.append((item.text, item.confidence))
        image_slot.image(image_with_boxes, use_container_width=True)

    with right_col:
        tabs = st.tabs(["Recognized"])
        with tabs[0]:
            df = pd.DataFrame(table_data, columns=["Text", "Confidence"])
            st.table(df)

if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Top-level boundary: log the full traceback for operators and show a
        # generic message to the user. The %s placeholder is required; passing
        # `e` without one makes logging fail to format the record and the
        # original error is lost.
        logger.exception("Error: %s", e)
        st.error("Internal Error!")