import os
import tempfile
import gradio as gr
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from ast import literal_eval
from PIL import Image

# Load the model on the available device(s)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
)

# Load the processor
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
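
# Note: device_map="auto" requires the `accelerate` package to be installed;
# torch_dtype="auto" loads the checkpoint in its native precision.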
# Define the extraction prompts
other_benefits = '''Extract the following information in the given format:
{'other_benefits_and_information': {
        '401k eru': {'This Period': '', 'Year-to-Date': ''}},
 'quota summary':
        {
        'sick': '',
        'vacation': '',
        },
 'payment method': 'e.g. Direct payment',
 'Amount': 'e.g. 12.99'
}
'''
tax_deductions = '''Extract the following information in the given format:
{
'tax_deductions': {
    'federal': {
        'withholding tax': {'Amount': '', 'Year-To-Date': ''},
        'ee social security tax': {'Amount': '', 'Year-To-Date': ''},
        'ee medicare tax': {'Amount': '', 'Year-To-Date': ''}},
    'california': {
        'withholding tax': {'Amount': '', 'Year-To-Date': ''},
        'ee disability tax': {'Amount': '', 'Year-To-Date': ''}}}
}
'''
intro = '''Extract the following information in the given format:
{
'check date': '',
'name': '',
'address': ''
}
'''
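
# The prompts above request Python-style dict literals, which is why the model
# replies are parsed with ast.literal_eval (rather than json.loads) below.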

def format_nested_dict(data, indent=0):
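    """Pretty-print a (possibly nested) dict as indented `key: value` lines."""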
    formatted_str = ""
    indent_str = "    " * indent  # Indentation for the current level
    for key, value in data.items():
        if isinstance(value, dict):
            # Recurse into nested dictionaries, one level deeper
            formatted_str += f"{indent_str}{key}:\n"
            formatted_str += format_nested_dict(value, indent + 1)
        else:
            formatted_str += f"{indent_str}{key}: {value}\n"
    return formatted_str

def process_function(image_path, prompt):
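    """Run a single prompt against the payslip image.

    Returns a dict when the model reply parses as a Python/JSON literal,
    otherwise the raw generated text.
    """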
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image_path,  # Use the file path here
                },
                {"type": "text", "text": prompt},
            ],
        }
    ]

    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)  # Follow the model's device instead of hard-coding "cuda"

    # Inference: generate the output and strip the prompt tokens from it
    generated_ids = model.generate(**inputs, max_new_tokens=1500)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    # The model usually wraps its answer in a ```json or ``` fence;
    # strip the fence and parse the literal, falling back to the raw text.
    raw = output_text[0]
    try:
        almost_json = raw.split('```json\n')[-1].split('\n```')[0]
        result = literal_eval(almost_json)
    except Exception:
        try:
            almost_json = raw.split('```\n')[-1].split('\n```')[0]
            result = literal_eval(almost_json)
        except Exception:
            result = raw  # Could not parse; return the raw text
    return result

def process_document(image):
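    """Gradio handler: run all three extraction prompts on the uploaded image.

    Returns each result twice: as pretty-printed text for the Textbox outputs
    and as the parsed object for the JSON outputs.
    """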
    # Save the uploaded image (a NumPy array from Gradio) to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_file:
        Image.fromarray(image).save(tmp_file.name)  # Convert to PIL and save
        image_path = tmp_file.name  # Path of the saved file

    # Process the image with each prompt
    one = process_function(image_path, other_benefits)
    two = process_function(image_path, tax_deductions)
    three = process_function(image_path, intro)

    # Format for the text outputs; fall back to the raw string when parsing failed
    text_one = format_nested_dict(one) if isinstance(one, dict) else str(one)
    text_two = format_nested_dict(two) if isinstance(two, dict) else str(two)
    text_three = format_nested_dict(three) if isinstance(three, dict) else str(three)

    # Delete the temporary file after use
    os.remove(image_path)

    return text_one, text_two, text_three, one, two, three

# Create the Gradio interface
demo = gr.Interface(
    fn=process_document,
    inputs="image",  # Gradio handles the image upload
    outputs=[
        gr.Textbox(label="Other Benefits and Information"),  # Formatted text outputs
        gr.Textbox(label="Tax Deductions Information"),
        gr.Textbox(label="Introduction"),
        gr.JSON(label="Other Benefits and Information"),  # Raw JSON outputs
        gr.JSON(label="Tax Deductions Information"),
        gr.JSON(label="Introduction"),
    ],
    title="<div style='text-align: center;'>Information Extraction From PaySlip</div>",
    examples=[["Slip_1.jpg"], ["Slip_2.jpg"]],
    cache_examples=False,
)

demo.launch()
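
# For a publicly shareable link (e.g. when running outside a Space), Gradio
# also supports demo.launch(share=True); the default serves on localhost.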