|
|
|
|
|
|
|
|
|
import csv
import os
import re

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration, TextStreamer
|
|
|
# Runtime flags, evaluated once from the environment at import time.
IS_SPACES_ZERO = os.getenv("SPACES_ZERO_GPU", "0") == "1"  # env SPACES_ZERO_GPU == "1"
IS_SPACE = "SPACE_ID" in os.environ  # True whenever SPACE_ID is set at all
# When True, weights are loaded from a mounted Google Drive folder instead of the Hub.
IS_GDRVIE = False

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

LOW_MEMORY = os.getenv("LOW_MEMORY", "0") == "1"

print(f"Using device: {device}")
print(f"Low memory mode: {LOW_MEMORY}")

# Hugging Face access token used when downloading the model (None if unset).
HF_TOKEN = os.getenv("HF_TOKEN")
|
|
|
|
|
model_name = "Llama-3.2-11B-Vision-Instruct"

if IS_GDRVIE:
    # Local copy of the weights on a mounted Google Drive (e.g. in Colab).
    model_path = "/content/drive/MyDrive/models/" + model_name
    model = MllamaForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(model_path)
else:
    # Download from the Hugging Face Hub, authenticating with HF_TOKEN.
    model_name = "ruslanmv/Llama-3.2-11B-Vision-Instruct"
    use_cuda = device == "cuda"
    model = MllamaForConditionalGeneration.from_pretrained(
        model_name,
        token=HF_TOKEN,  # `use_auth_token` is deprecated in transformers
        torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
        device_map="auto" if use_cuda else None,
    )
    if not use_cuda:
        # With device_map="auto", accelerate has already placed the weights;
        # calling .to() on such a dispatched model is redundant and can raise
        # when some layers are offloaded. Only move the model explicitly when
        # no device map was used.
        model.to(device)
    processor = AutoProcessor.from_pretrained(model_name, token=HF_TOKEN)

# NOTE(review): presumably defensive; from_pretrained normally ties weights
# already when the config requests it.
if hasattr(model, "tie_weights"):
    model.tie_weights()
|
|
|
# Output layout shown to the model: a "Table n:" header line per table, then
# the table's CSV rows, with a blank line between tables. The header format is
# what extract_and_save_tables looks for when splitting the answer into tables.
example = '''Table 1:
header1,header2,header3
value1,value2,value3

Table 2:
header1,header2,header3
value1,value2,value3
'''

# Instruction sent with every uploaded image. The header wording must agree
# with ``example`` so the model produces headers the parser recognizes.
prompt_message = """Please extract all tables from the image and generate CSV files.
Start each table with a header line of the form "Table n:", where n is the table number, followed by the table's rows.
You must use CSV format with commas as the delimiter. Do not use markdown format. Ensure you use the original table headers and content from the image. Wrap any value that contains a comma in double quotes.
Only answer with the CSV content. Don't explain the tables.
An example of the formatting output is as follows:
""" + example
|
|
|
|
|
|
|
def stream_response(inputs, max_new_tokens=2000, do_sample=False):
    """Generate a reply for ``inputs`` and yield the decoded answer of each sequence.

    Despite the name, generation finishes before anything is yielded:
    ``model.generate`` returns the completed token ids. The ``TextStreamer``
    only echoes tokens to stdout while they are being produced.

    Args:
        inputs: Processor output for the prompt; must contain ``input_ids``.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: Sample instead of decoding greedily. Off by default so that
            the transcribed table text is deterministic for a given image.

    Yields:
        The generated text of each returned sequence, without the prompt.
    """
    # skip_prompt=True limits the console echo to the model's answer.
    streamer = TextStreamer(processor.tokenizer, skip_prompt=True)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        streamer=streamer,
    )
    # generate() returns prompt + continuation. Decode only the continuation so
    # the echoed prompt (whose example contains "Table 1:" / "Table 2:") can
    # never be mistaken for extracted tables.
    prompt_length = inputs["input_ids"].shape[-1]
    for sequence in output_ids:
        yield processor.decode(sequence[prompt_length:], skip_special_tokens=True)
|
|
|
|
|
@spaces.GPU
def predict(message, image):
    """Send ``message`` together with ``image`` to the vision model and save any tables.

    Args:
        message: Text instruction sent alongside the image.
        image: Image handed to the processor (the Gradio UI supplies a PIL image).

    Returns:
        The list of CSV file paths returned by ``extract_and_save_tables``.

    Raises:
        ValueError: If ``image`` is None.
    """
    if image is None:
        raise ValueError("An image is required to extract tables.")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": message},
            ],
        }
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

    # The chat template already contains the BOS token; letting the processor
    # add special tokens again would prepend a duplicate BOS.
    inputs = processor(
        image, input_text, add_special_tokens=False, return_tensors="pt"
    ).to(device)

    full_response = "".join(stream_response(inputs))
    return extract_and_save_tables(full_response)
|
|
|
|
|
# NOTE(review): module-level list that nothing in the visible code reads;
# extract_and_save_tables builds and returns its own local ``files_list``.
files_list = []
|
|
|
def clean_full_response(full_response, prompt=None):
    """Remove an echoed prompt from a model response.

    Args:
        full_response: Decoded model output. ``None`` or ``""`` yields ``""``.
        prompt: Text to strip from the response. Defaults to the module-level
            ``prompt_message`` that is sent with every image.

    Returns:
        The response with every occurrence of ``prompt`` removed and leading
        and trailing whitespace stripped.
    """
    if not full_response:
        return ""
    if prompt is None:
        prompt = prompt_message
    return full_response.replace(prompt, "").strip()
|
|
|
# A line that starts a new table in the model output, such as "Table 1:",
# "**Table 2: Revenue**", "Table 3" or "table_4.csv". It must start the line
# (after optional markdown '*'/'#'), so a cell such as "Table 5" inside a CSV row
# is not mistaken for a header.
_TABLE_HEADER_RE = re.compile(
    r"^[\s*#]*(?:table\s+\d+[\s*]*(?::|$)|table_\d+\.csv\b)", re.IGNORECASE
)


def _split_tables(text):
    """Group ``text`` into tables, each a list of parsed CSV rows (lists of str).

    Lines before the first header line are ignored (e.g. chat-template
    residue), as are blank lines and markdown code-fence lines. A header with
    no rows after it produces no table.
    """
    blocks = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if _TABLE_HEADER_RE.match(line):
            blocks.append([])
        elif blocks and line and not line.startswith("```"):
            blocks[-1].append(line)
    # csv.reader honours CSV quoting, so a quoted value like "1,000" stays one
    # cell instead of being split on its comma.
    return [list(csv.reader(block, skipinitialspace=True)) for block in blocks if block]


def extract_and_save_tables(full_response, output_dir=None):
    """Parse the model's answer into CSV tables and write each to its own file.

    The answer is expected to contain one block per table, introduced by a
    header line such as ``Table 1:`` and followed by comma-separated rows.
    Tables are written, in order of appearance, to ``table_1.csv``,
    ``table_2.csv``, ...

    Args:
        full_response: Raw text generated by the model.
        output_dir: Directory for the CSV files (created if missing). ``None``
            writes into the current working directory and returns bare file
            names.

    Returns:
        List of the CSV file paths written; empty if no table rows were found.
    """
    cleaned_response = clean_full_response(full_response)

    files_list = []
    for i, rows in enumerate(_split_tables(cleaned_response), start=1):
        table_name = f"table_{i}.csv"
        if output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)
            table_name = os.path.join(output_dir, table_name)
        # Explicit UTF-8: table content may contain non-ASCII text, and the
        # platform's default encoding is not guaranteed to handle it.
        with open(table_name, mode="w", newline="", encoding="utf-8") as file:
            csv.writer(file).writerows(rows)
        files_list.append(table_name)

    return files_list
|
|
|
|
|
|
|
def gradio_app():
    """Build the table-extraction Gradio interface and launch it.

    ``launch(debug=True)`` keeps the calling thread busy serving the app.
    """

    def process_image(image):
        """Extract tables from ``image``; return (status text, CSV paths or None)."""
        if image is None:
            return "Please upload an image first.", None
        files = predict(prompt_message, image)
        if not files:
            # Don't claim success when the model's answer contained no tables.
            return "No tables were found in the image.", None
        return "Tables extracted and saved as CSV files.", files

    image_input = gr.Image(type="pil", label="Upload Image")
    output_text = gr.Textbox(label="Extraction Status")
    # One CSV file is produced per table, so the output must accept several files.
    file_output = gr.File(label="Download CSV files", file_count="multiple")

    iface = gr.Interface(
        fn=process_image,
        inputs=[image_input],
        outputs=[output_text, file_output],
        title="Table Extractor and CSV Converter",
        description="Upload an image to extract tables and download CSV files.",
        allow_flagging="never",
    )
    iface.launch(debug=True)
|
|
|
|
|
|
|
# Launch the UI only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    gradio_app()