Yash Malviya commited on
Commit
ac8933f
1 Parent(s): be56acc

added app.py

Browse files
Files changed (2) hide show
  1. app.py +80 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoProcessor
3
+ from PIL import Image
4
+ import requests
5
+ import gradio as gr
6
+ import pandas as pd
7
+ import subprocess
8
+ import os
9
+
10
+ # Install flash-attn without CUDA build
11
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
12
+
13
+ # Load the model and processor
14
+ model_id = "yifeihu/TB-OCR-preview-0.1"
15
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ model = AutoModelForCausalLM.from_pretrained(
18
+ model_id,
19
+ device_map="cuda",
20
+ trust_remote_code=True,
21
+ torch_dtype="auto",
22
+ attn_implementation='flash_attention_2',
23
+ load_in_4bit=True
24
+ )
25
+ processor = AutoProcessor.from_pretrained(model_id,
26
+ trust_remote_code=True,
27
+ num_crops=16
28
+ )
29
+
30
+ # Define the OCR function
31
+ def phi_ocr(image):
32
+ question = "Convert the text to markdown format."
33
+ prompt_message = [{
34
+ 'role': 'user',
35
+ 'content': f'<|image_1|>\n{question}',
36
+ }]
37
+ prompt = processor.tokenizer.apply_chat_template(prompt_message, tokenize=False, add_generation_prompt=True)
38
+ inputs = processor(prompt, [image], return_tensors="pt").to("cuda")
39
+ generation_args = {
40
+ "max_new_tokens": 1024,
41
+ "temperature": 0.1,
42
+ "do_sample": False
43
+ }
44
+ generate_ids = model.generate(**inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
45
+ generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
46
+ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
47
+ response = response.split("<image_end>")[0]
48
+ return response
49
+
50
+ # Define the function to process multiple images and save results to a CSV
51
+ def process_images(input_images):
52
+ results = []
53
+ for index, image in enumerate(input_images):
54
+ extracted_text = phi_ocr(image)
55
+ results.append({
56
+ 'index': index,
57
+ 'extracted_text': extracted_text
58
+ })
59
+
60
+ # Convert to DataFrame and save to CSV
61
+ df = pd.DataFrame(results)
62
+ output_csv = "extracted_entities.csv"
63
+ df.to_csv(output_csv, index=False)
64
+
65
+ return f"Processed {len(input_images)} images and saved to {output_csv}", output_csv
66
+
67
+ # Gradio UI
68
+ with gr.Blocks() as demo:
69
+ gr.Markdown("# OCR with TB-OCR-preview-0.1")
70
+ gr.Markdown("Upload multiple images to extract and convert text to markdown format.")
71
+ gr.Markdown("[Check out the model here](https://huggingface.co/yifeihu/TB-OCR-preview-0.1)")
72
+
73
+ with gr.Row():
74
+ input_images = gr.Image(type="pil", label="Upload Images", tool="editor", source="upload", multiple=True)
75
+ output_text = gr.Textbox(label="Status")
76
+ output_csv_link = gr.File(label="Download CSV")
77
+
78
+ input_images.change(fn=process_images, inputs=input_images, outputs=[output_text, output_csv_link])
79
+
80
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ torchvision
4
+ Pillow
5
+ pandas
6
+ gradio
7
+ flash-attn