Spaces:
Runtime error
Runtime error
import torch | |
from transformers import AutoModelForCausalLM, AutoProcessor | |
from PIL import Image | |
import requests | |
import gradio as gr | |
import pandas as pd | |
import subprocess | |
import os | |
# Install flash-attn without CUDA build.
# `env=` replaces the child's environment wholesale, so the current
# environment (PATH, HOME, virtualenv variables, ...) is carried over and the
# skip flag is layered on top; otherwise `pip` may not resolve at all or may
# install into a different interpreter.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
# Load the model and processor
model_id = "yifeihu/TB-OCR-preview-0.1"

# Resolved once at import time. Note that the loading call below pins
# device_map to "cuda" directly rather than reading this value.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Keyword arguments forwarded verbatim to AutoModelForCausalLM.from_pretrained.
_model_kwargs = {
    "device_map": "cuda",
    "trust_remote_code": True,
    "torch_dtype": "auto",
    "attn_implementation": "flash_attention_2",
    "load_in_4bit": True,
}
model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)

# Keyword arguments forwarded verbatim to AutoProcessor.from_pretrained.
_processor_kwargs = {
    "trust_remote_code": True,
    "num_crops": 16,
}
processor = AutoProcessor.from_pretrained(model_id, **_processor_kwargs)
# Define the OCR function
def phi_ocr(image):
    """Run the OCR model on one image and return the generated text.

    Args:
        image: The image to transcribe. It is handed to the processor as a
            single-element list; presumably a PIL image (confirm against the
            caller).

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        sliced off), truncated at the first ``"<image_end>"`` marker if one
        is present.
    """
    question = "Convert the text to markdown format."
    prompt_message = [{
        'role': 'user',
        'content': f'<|image_1|>\n{question}',
    }]
    prompt = processor.tokenizer.apply_chat_template(
        prompt_message, tokenize=False, add_generation_prompt=True
    )

    # Place the inputs on whatever device the model lives on instead of
    # hard-coding a device string.
    inputs = processor(prompt, [image], return_tensors="pt").to(model.device)

    # Greedy decoding (do_sample=False); sampling-only parameters such as
    # temperature would have no effect here and only trigger a warning.
    generation_args = {
        "max_new_tokens": 1024,
        "do_sample": False,
    }
    generate_ids = model.generate(
        **inputs,
        eos_token_id=processor.tokenizer.eos_token_id,
        **generation_args,
    )

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = inputs['input_ids'].shape[1]
    generate_ids = generate_ids[:, prompt_length:]

    response = processor.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    # Drop everything from the first "<image_end>" marker onwards.
    return response.split("<image_end>")[0]
# Define the function to process multiple images and save results to a CSV
def process_images(input_images):
    """OCR every image and write the results to ``extracted_entities.csv``.

    Args:
        input_images: A list or tuple of images, a single image, or ``None``.
            ``None`` is treated as "no images"; any value that is not a
            list/tuple is treated as one image.

    Returns:
        A ``(status_message, csv_path)`` tuple. The CSV has the columns
        ``index`` and ``extracted_text``, one row per image; it is written
        (header only) even when there are no images.
    """
    # Normalise the input so the loop below always sees a list.
    if input_images is None:
        images = []
    elif isinstance(input_images, (list, tuple)):
        images = list(input_images)
    else:
        images = [input_images]

    results = [
        {'index': index, 'extracted_text': phi_ocr(image)}
        for index, image in enumerate(images)
    ]

    # Convert to DataFrame and save to CSV. Columns are pinned so that an
    # empty run still produces a file with a header row.
    df = pd.DataFrame(results, columns=['index', 'extracted_text'])
    output_csv = "extracted_entities.csv"
    df.to_csv(output_csv, index=False)
    return f"Processed {len(images)} images and saved to {output_csv}", output_csv
# Gradio UI
def _ocr_uploaded_files(uploaded_files):
    """Open the uploaded files as RGB PIL images and run process_images on them.

    Args:
        uploaded_files: The value of the multi-file upload component; ``None``
            or empty when nothing is uploaded. Items are either path strings
            or file wrappers exposing ``.name`` (presumably depending on the
            installed Gradio version; both are handled).

    Returns:
        The ``(status_message, csv_path)`` tuple from ``process_images``, or a
        notice and ``None`` when there is nothing to process.
    """
    if not uploaded_files:
        return "No images uploaded.", None

    images = []
    for item in uploaded_files:
        path = getattr(item, "name", item)
        with Image.open(path) as img:
            # convert() returns a fully loaded copy, so the file can be closed.
            images.append(img.convert("RGB"))
    return process_images(images)


with gr.Blocks() as demo:
    gr.Markdown("# OCR with TB-OCR-preview-0.1")
    gr.Markdown("Upload multiple images to extract and convert text to markdown format.")
    gr.Markdown("[Check out ](https://huggingface.co/yifeihu/TB-OCR-preview-0.1)")
    with gr.Row():
        # gr.Image accepts a single image only; a multi-file upload restricted
        # to image types is used to accept several images at once.
        input_images = gr.File(
            label="Upload Images",
            file_count="multiple",
            file_types=["image"],
        )
        output_text = gr.Textbox(label="Status")
        output_csv_link = gr.File(label="Download CSV")
    input_images.change(
        fn=_ocr_uploaded_files,
        inputs=input_images,
        outputs=[output_text, output_csv_link],
    )

demo.launch()