import torch
import streamlit as st
from PIL import Image
from deepseek_vl2.serve.inference import load_model, deepseek_generate, convert_conversation_to_prompts
from deepseek_vl2.serve.app_modules.utils import configure_logger, strip_stop_words, pil_to_base64
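
# NOTE: the `deepseek_vl2` imports above assume the DeepSeek-VL2 package is
# installed (e.g. from https://github.com/deepseek-ai/DeepSeek-VL2); the
# serve helpers imported here come from its demo-serving code.
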
# Set Page Config (Must be first!)
st.set_page_config(layout="wide")
# Set up logging
logger = configure_logger()
# Model list and in-process deployment cache
MODELS = ["deepseek-ai/deepseek-vl2-tiny"]
DEPLOY_MODELS = {}
IMAGE_TOKEN = "<image>"
# Fetch model
def fetch_model(model_name: str, dtype=torch.bfloat16):
    global DEPLOY_MODELS
    if model_name not in DEPLOY_MODELS:
        logger.info(f"Loading {model_name}...")
        model_info = load_model(model_name, dtype=dtype)
        tokenizer, model, vl_chat_processor = model_info
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device).eval()  # Move to appropriate device
        DEPLOY_MODELS[model_name] = (tokenizer, model, vl_chat_processor)
        logger.info(f"Loaded {model_name} on {device}")
    return DEPLOY_MODELS[model_name]
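
# Illustrative use of the cache (hypothetical call; the tuple order matches
# what load_model returns):
#   tokenizer, model, processor = fetch_model("deepseek-ai/deepseek-vl2-tiny")
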
# Generate prompt with history
def generate_prompt_with_history(text, images, history, vl_chat_processor, tokenizer, max_length=2048):
    conversation = vl_chat_processor.new_chat_template()
    if history:
        conversation.messages = history
    if images:
        text = f"{IMAGE_TOKEN}\n{text}"
        text = (text, images)
    conversation.append_message(conversation.roles[0], text)
    conversation.append_message(conversation.roles[1], "")
    return conversation

# Convert a conversation to the Gradio chatbot format
# (carried over from the DeepSeek-VL2 demo; not called by this Streamlit UI)
def to_gradio_chatbot(conv):
    ret = []
    for i, (role, msg) in enumerate(conv.messages[conv.offset:]):
        if i % 2 == 0:
            if isinstance(msg, tuple):
                msg, images = msg
                for image in images:
                    img_b64 = pil_to_base64(image, "user upload", max_size=800, min_size=400)
                    msg = msg.replace(IMAGE_TOKEN, img_b64, 1)
            ret.append([msg, None])
        else:
            ret[-1][-1] = msg
    return ret
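
# Illustrative return shape (assumed, for reference): a list of
# [user_message, assistant_message] pairs, e.g.
#   [["<base64 image tag>\nExtract all text...", "EXTRACTED TEXT"]]
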
# Predict function (simplified for OCR)
def predict(text, images, model_name="deepseek-ai/deepseek-vl2-tiny"):
    logger.info("Starting predict function...")
    tokenizer, vl_gpt, vl_chat_processor = fetch_model(model_name)
    if not text:
        logger.warning("Empty text input detected.")
        return "Empty context."
    logger.info("Processing images...")
    pil_images = [Image.open(img).convert("RGB") for img in images] if images else []
    conversation = generate_prompt_with_history(
        text, pil_images, [], vl_chat_processor, tokenizer
    )
    all_conv, _ = convert_conversation_to_prompts(conversation)
    stop_words = conversation.stop_str
    full_response = ""
    logger.info("Generating response...")
    try:
        with torch.no_grad():
            for x in deepseek_generate(
                conversations=all_conv,
                vl_gpt=vl_gpt,
                vl_chat_processor=vl_chat_processor,
                tokenizer=tokenizer,
                stop_words=stop_words,
                max_length=2048,
                temperature=0.1,
                top_p=0.9,
                repetition_penalty=1.1
            ):
                full_response += x
        response = strip_stop_words(full_response, stop_words)
        logger.info("Generation complete.")
        return response
    except Exception as e:
        logger.error(f"Error in generation: {str(e)}")
        return f"Error: {str(e)}"

# OCR extraction function
def extract_text(image):
    if image is None:
        return "Please upload an image."
    prompt = (
        "Extract all text from this image exactly as it appears, ensuring the output is in "
        "English only. Preserve spaces, bullets, numbers, and all formatting. Do not translate, "
        "generate, or include text in any other language. Stop at the last character of the image text."
    )
    logger.info("Starting text extraction...")
    extracted_text = predict(prompt, [image])
    return extracted_text
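
# Illustrative call outside Streamlit (hypothetical filename; predict() opens
# paths or file-like objects via PIL):
#   print(extract_text("job_ad.png"))
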
# User Selection State
if "user_selected" not in st.session_state:
st.session_state["user_selected"] = False
if "selected_user" not in st.session_state:
st.session_state["selected_user"] = None
# User Selection Page
@st.cache_data
def get_user_names():
return ["DS", "NW", "RB", "IG", "AR", "NU"]
user_names = get_user_names()
# Retrieve cached user from query params
query_params = st.query_params
cached_user = query_params.get("user", None)
# "user_selected" is guaranteed to exist after the initialization above, so
# restore from the URL by testing the stored value, not key membership
if cached_user and not st.session_state["user_selected"]:
    st.session_state["user_selected"] = True
    st.session_state["selected_user"] = cached_user
if not st.session_state["user_selected"]:
st.title("πŸ‘‹ Let’s Get Started! Identify Yourself to Begin")
selected_user = st.selectbox("Choose your name:", user_names, index=user_names.index(cached_user) if cached_user in user_names else None)
continue_button = st.button("Continue", disabled=not selected_user)
if continue_button:
st.session_state["user_selected"] = True
st.session_state["selected_user"] = selected_user
st.query_params["user"] = selected_user
st.rerun()
if not selected_user:
st.warning("⚠ Please select a user to continue.")
else:
st.write(f"βœ… Welcome Back, {st.session_state['selected_user']}!")
# Main UI (only loads after user selection)
if st.session_state["user_selected"]:
    st.markdown("<h1 style='text-align: center;'>🔍 Extract Job Info in One Click</h1>", unsafe_allow_html=True)
    uploaded_file = st.file_uploader("Upload an Image (PNG, JPG, JPEG)", type=["png", "jpg", "jpeg"])
    if uploaded_file:
        # extract_text opens the upload via PIL itself, so no separate Image.open is needed here
        extracted_text = extract_text(uploaded_file)
        col1, col2 = st.columns(2)
        with col1:
            st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)
        with col2:
            st.markdown(
                f"""
                <div style="border: 1px solid #ccc; padding: 10px; width: 100%; white-space: pre-wrap; overflow: hidden; text-align: left;">
                {extracted_text}
                </div>
                """,
                unsafe_allow_html=True
            )
        st.session_state["extracted_text"] = extracted_text
        errors_text = st.text_area("Paste Any Errors Here", height=200)
        rating = st.radio("Rate the OCR Extraction (5-1)", options=[5, 4, 3, 2, 1], index=None, horizontal=True)
        ref_number = st.text_input("Enter Reference Number", max_chars=10)
        if ref_number and not ref_number.isdigit():
            st.warning("⚠ Reference Number must be a number.")
            ref_number = ""
        if not ref_number or rating is None:
            st.warning("⚠ Please enter a Reference Number and select a Rating to proceed.")
            st.button("Submit", disabled=True)
        else:
            if st.button("Submit"):
                st.success("✅ Submitted successfully!")
                # No Google Drive or Sheets upload; just a confirmation