Spaces:
Sleeping
Sleeping
import hashlib
import html

import streamlit as st
import torch
from PIL import Image

from deepseek_vl2.serve.app_modules.utils import configure_logger, strip_stop_words, pil_to_base64
from deepseek_vl2.serve.inference import load_model, deepseek_generate, convert_conversation_to_prompts
# Streamlit requires the page config to be the first Streamlit call in the script.
st.set_page_config(layout="wide")

# Logger shared by every function below.
logger = configure_logger()

# Model ids offered by the app, models already loaded by fetch_model (keyed
# by model name), and the placeholder marking an image's position in a prompt.
MODELS: list = ["deepseek-ai/deepseek-vl2-tiny"]
DEPLOY_MODELS: dict = {}
IMAGE_TOKEN: str = "<image>"
# Fetch model
@st.cache_resource
def _model_registry() -> dict:
    """Return one dict of loaded models that persists across Streamlit reruns.

    Streamlit re-executes this script on every interaction, which re-creates
    the module-level ``DEPLOY_MODELS`` dict; ``st.cache_resource`` keeps this
    object alive so models are loaded only once.
    """
    return {}


def fetch_model(model_name: str, dtype=torch.bfloat16):
    """Load ``model_name`` once and return ``(tokenizer, model, vl_chat_processor)``.

    The model is moved to CUDA when available (CPU otherwise) and put in eval
    mode. Results are cached per ``(model_name, dtype)`` so a request for a
    different dtype is not silently served by a previously loaded one.

    Args:
        model_name: Model identifier passed to ``load_model``.
        dtype: Torch dtype passed to ``load_model``.

    Returns:
        The cached ``(tokenizer, model, vl_chat_processor)`` tuple.
    """
    registry = _model_registry()
    key = (model_name, dtype)
    if key not in registry:
        logger.info(f"Loading {model_name}...")
        tokenizer, model, vl_chat_processor = load_model(model_name, dtype=dtype)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device).eval()  # Move to appropriate device
        registry[key] = (tokenizer, model, vl_chat_processor)
        logger.info(f"Loaded {model_name} on {device}")
    # Keep the module-level dict populated as before for any code inspecting it.
    DEPLOY_MODELS[model_name] = registry[key]
    return registry[key]
# Generate prompt with history
def generate_prompt_with_history(text, images, history, vl_chat_processor, tokenizer, max_length=2048):
    """Build a chat-template conversation for one user turn.

    Args:
        text: The user's prompt text.
        images: Images attached to this turn; may be empty.
        history: Prior message list to continue from; empty/falsy starts fresh.
            It is copied, so the caller's list is never mutated.
        vl_chat_processor: Object providing ``new_chat_template()``.
        tokenizer: Unused; kept for signature compatibility.
        max_length: Unused; kept for signature compatibility.

    Returns:
        The conversation with the user message and an empty assistant
        message appended. When images are given, the user message is the
        tuple ``(text, images)`` and ``text`` carries exactly one
        ``IMAGE_TOKEN`` per image.
    """
    conversation = vl_chat_processor.new_chat_template()
    if history:
        # Copy: append_message below would otherwise grow the caller's list.
        conversation.messages = list(history)
    if images:
        # Reconcile placeholder count with image count: pad missing tags at
        # the front (one image -> "<image>\n" + text), drop surplus ones.
        missing = len(images) - text.count(IMAGE_TOKEN)
        if missing > 0:
            text = "\n".join([IMAGE_TOKEN] * missing) + "\n" + text
        elif missing < 0:
            text = text.replace(IMAGE_TOKEN, "", -missing)
        text = (text, images)
    conversation.append_message(conversation.roles[0], text)
    conversation.append_message(conversation.roles[1], "")
    return conversation
# Convert a conversation into Gradio chatbot [user, assistant] pairs
def to_gradio_chatbot(conv):
    """Pair up the visible messages of ``conv`` as ``[user_msg, bot_msg]`` lists.

    Messages from ``conv.offset`` onward are assumed to alternate user /
    assistant. A user message stored as ``(text, images)`` has each image
    rendered with ``pil_to_base64`` and substituted, in order, for one
    ``IMAGE_TOKEN`` occurrence. A trailing user message is paired with None.
    """
    pairs = []
    visible = conv.messages[conv.offset:]
    for idx, (_role, content) in enumerate(visible):
        if idx % 2:
            # Odd positions are assistant replies: fill in the open pair.
            pairs[-1][-1] = content
            continue
        if isinstance(content, tuple):
            content, attached = content
            for picture in attached:
                encoded = pil_to_base64(picture, "user upload", max_size=800, min_size=400)
                content = content.replace(IMAGE_TOKEN, encoded, 1)
        pairs.append([content, None])
    return pairs
# Predict function (simplified for OCR)
def _open_rgb(source):
    """Open ``source`` with PIL and return an RGB-converted image.

    ``convert`` yields a separate image, so the opened source can be closed
    by the ``with`` block instead of leaking its handle.
    """
    with Image.open(source) as opened:
        return opened.convert("RGB")


def predict(
    text,
    images,
    model_name="deepseek-ai/deepseek-vl2-tiny",
    *,
    max_length=2048,
    temperature=0.1,
    top_p=0.9,
    repetition_penalty=1.1,
):
    """Run one single-turn vision-language generation and return its text.

    Args:
        text: Prompt text; an empty/falsy value short-circuits.
        images: Image sources accepted by ``PIL.Image.open`` (e.g. paths or
            file objects), or a falsy value for none.
        model_name: Model to load via ``fetch_model``.
        max_length, temperature, top_p, repetition_penalty: Generation
            settings forwarded to ``deepseek_generate``; the defaults are the
            values previously hard-coded for OCR.

    Returns:
        The generated text with stop words stripped, ``"Empty context."``
        for empty ``text``, or ``"Error: <message>"`` if generation raises.
    """
    logger.info("Starting predict function...")
    # Check input before fetch_model so an empty prompt never loads a model.
    if not text:
        logger.warning("Empty text input detected.")
        return "Empty context."
    tokenizer, vl_gpt, vl_chat_processor = fetch_model(model_name)

    logger.info("Processing images...")
    pil_images = [_open_rgb(img) for img in images] if images else []
    conversation = generate_prompt_with_history(
        text, pil_images, [], vl_chat_processor, tokenizer
    )
    all_conv, _ = convert_conversation_to_prompts(conversation)
    stop_words = conversation.stop_str

    logger.info("Generating response...")
    try:
        with torch.no_grad():
            chunks = list(
                deepseek_generate(
                    conversations=all_conv,
                    vl_gpt=vl_gpt,
                    vl_chat_processor=vl_chat_processor,
                    tokenizer=tokenizer,
                    stop_words=stop_words,
                    max_length=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                )
            )
        response = strip_stop_words("".join(chunks), stop_words)
    except Exception as e:
        # UI boundary: log the full traceback, show the user a short message.
        logger.exception("Error in generation: %s", e)
        return f"Error: {e}"
    logger.info("Generation complete.")
    return response
# OCR extraction function
def extract_text(image):
    """Run the fixed OCR prompt against a single uploaded image.

    Returns the text produced by ``predict``, or an upload reminder when
    ``image`` is None.
    """
    if image is not None:
        ocr_prompt = (
            "Extract all text from this image exactly as it appears, ensuring the output is in English only. "
            "Preserve spaces, bullets, numbers, and all formatting. "
            "Do not translate, generate, or include text in any other language. "
            "Stop at the last character of the image text."
        )
        logger.info("Starting text extraction...")
        return predict(ocr_prompt, [image])
    return "Please upload an image."
# Seed per-session sign-in state once: has a user been chosen, and which one.
for _state_key, _state_default in (("user_selected", False), ("selected_user", None)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# User Selection Page
def get_user_names():
    """Return a fresh list of the user initials offered on the sign-in page."""
    roster = ("DS", "NW", "RB", "IG", "AR", "NU")
    return list(roster)


user_names = get_user_names()
# Retrieve cached user from query params
query_params = st.query_params
cached_user = query_params.get("user", None)
# The session-state keys always exist at this point (they are seeded at
# startup), so test the flag's value rather than key presence. The query
# string is user-controlled, so only accept names from the known roster.
if cached_user in user_names and not st.session_state["user_selected"]:
    st.session_state["user_selected"] = True
    st.session_state["selected_user"] = cached_user

# NOTE(review): the leading glyphs in the UI strings below ("π", "β") look like
# emoji/punctuation mis-decoded during a copy; the intended characters cannot be
# recovered from here -- restore them from the upstream source.
if not st.session_state["user_selected"]:
    st.title("π Letβs Get Started! Identify Yourself to Begin")
    selected_user = st.selectbox(
        "Choose your name:",
        user_names,
        index=user_names.index(cached_user) if cached_user in user_names else None,
    )
    continue_button = st.button("Continue", disabled=not selected_user)
    if continue_button:
        st.session_state["user_selected"] = True
        st.session_state["selected_user"] = selected_user
        # Persist the choice in the URL so a reload restores it.
        st.query_params["user"] = selected_user
        st.rerun()
    if not selected_user:
        st.warning("β Please select a user to continue.")
else:
    st.write(f"β Welcome Back, {st.session_state['selected_user']}!")
# Main UI (Only loads after user selection)
if st.session_state["user_selected"]:
    st.markdown("<h1 style='text-align: center;'>π Extract Job Info in One Click</h1>", unsafe_allow_html=True)
    uploaded_file = st.file_uploader("Upload an Image (PNG, JPG, JPEG)", type=["png", "jpg", "jpeg"])
    if uploaded_file:
        # Streamlit re-executes the script on every widget interaction (typing
        # in the text area, picking a rating, ...). Remember which upload the
        # stored OCR result belongs to so the model only runs for a new image.
        upload_digest = hashlib.sha256(uploaded_file.getvalue()).hexdigest()
        if st.session_state.get("extracted_digest") != upload_digest:
            st.session_state["extracted_text"] = extract_text(uploaded_file)
            st.session_state["extracted_digest"] = upload_digest
        extracted_text = st.session_state["extracted_text"]

        col1, col2 = st.columns(2)
        with col1:
            st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)
        with col2:
            # OCR output is arbitrary text read from the image: escape it before
            # embedding it in raw HTML, and turn newlines into <br> so a blank
            # line cannot end the HTML block and expose the rest to Markdown.
            safe_text = html.escape(str(extracted_text)).replace("\n", "<br>")
            box_style = (
                "border: 1px solid #ccc; padding: 10px; width: 100%; "
                "white-space: pre-wrap; overflow: hidden; text-align: left;"
            )
            st.markdown(f'<div style="{box_style}">{safe_text}</div>', unsafe_allow_html=True)

        errors_text = st.text_area("Paste Any Errors Here", height=200)
        rating = st.radio("Rate the OCR Extraction (5-1)", options=[5, 4, 3, 2, 1], index=None, horizontal=True)
        ref_number = st.text_input("Enter Reference Number", max_chars=10).strip()
        # str.isdigit() alone also accepts non-ASCII digits such as "²".
        if ref_number and not (ref_number.isascii() and ref_number.isdigit()):
            st.warning("β Reference Number must be a number.")
            ref_number = ""
        if not ref_number or rating is None:
            st.warning("β Please enter a Reference Number and select a Rating to proceed.")
            st.button("Submit", disabled=True)
        elif st.button("Submit"):
            # No Google Drive or Sheets upload; just a confirmation
            st.success("β Submitted successfully!")