pranshh committed
Commit d2212a0 · verified · 1 Parent(s): 45f2b08

Update app.py

Files changed (1)
  1. app.py +67 -142
app.py CHANGED
@@ -7,151 +7,76 @@ Original file is located at
     https://colab.research.google.com/drive/1vzsQ17-W1Vy6yJ60XUwFy0QRkOR_SIg7
 """
 
-import gradio as gr
-from transformers import AutoModel, AutoTokenizer
-from PIL import Image
-import os
-
-revision = "5364fe1ab774ef13c2c79023dc91d8c1e7cfdce4"
-
-# Load tokenizer and model
-tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
-model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, low_cpu_mem_usage=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
-model = model.eval()
-
-# Function to perform OCR and optional keyword search
-def process_image_with_search(image, keyword):
-    try:
-        # Save the PIL image to a temporary file
-        temp_img_path = "temp_image.png"
-        image.save(temp_img_path)
-
-        # Perform OCR with the model using the file path
-        extracted_text = model.chat(tokenizer, temp_img_path, ocr_type='format')
-
-        # Delete the temporary file
-        if os.path.exists(temp_img_path):
-            os.remove(temp_img_path)
-
-        # Convert extracted text to string if it's not already
-        extracted_text = extracted_text if isinstance(extracted_text, str) else str(extracted_text)
-
-        # If a keyword is provided, search for it
-        if keyword:
-            # Perform keyword search (case-insensitive)
-            if keyword.lower() in extracted_text.lower():
-                # Highlight the keyword in the extracted text
-                highlighted_text = extracted_text.replace(keyword, f"**{keyword}**", 1)  # Highlight first occurrence
-                result = f"Keyword '{keyword}' found:\n\n{highlighted_text}"
-            else:
-                result = f"Keyword '{keyword}' not found in the extracted text.\n\nExtracted Text:\n{extracted_text}"
-        else:
-            # If no keyword is provided, return the extracted text without searching
-            result = f"Extracted Text:\n\n{extracted_text}"
-
-        return result
-    except Exception as e:
-        return str(e)  # Return error message in case of failure
-
-# Define Gradio interface
-iface = gr.Interface(
-    fn=process_image_with_search,  # The function to process the image and search keyword
-    inputs=[gr.Image(type='pil'), gr.Textbox(label="Enter keyword to search (optional)")],  # Image input + keyword input
-    outputs='text',  # Output will be plain text with the search result
-    title="OCR with GOT and Keyword Search",
-    description="Upload an image to get OCR results. You can also search for a keyword in the extracted text."
-)
-
-# Launch the interface
-iface.launch(debug=True)
-
-# !pip install --upgrade git+https://github.com/huggingface/transformers.git byaldi accelerate flash-attn qwen_vl_utils pdf2image gradio
-# !sudo apt-get install -y poppler-utils
-
-# from byaldi import RAGMultiModalModel
-# from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
-# from qwen_vl_utils import process_vision_info
-# import torch
-# import gradio as gr
-# from PIL import Image
-
-# processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
-
-# # Initialize the model with float16 precision and handle fallback to CPU
-# def load_model():
-#     try:
-#         vlm = Qwen2VLForConditionalGeneration.from_pretrained(
-#             "Qwen/Qwen2-VL-2B-Instruct",
-#             torch_dtype=torch.float16,
-#             attn_implementation="flash_attention_2",  # FlashAttention enabled
-#             device_map="cuda"
-#         )
-#         print("Model loaded with FlashAttention on GPU")
-#     except RuntimeError as e:
-#         if "FlashAttention only supports Ampere GPUs" in str(e):
-#             print("FlashAttention not supported. Falling back to standard attention.")
-#             vlm = Qwen2VLForConditionalGeneration.from_pretrained(
-#                 "Qwen/Qwen2-VL-2B-Instruct",
-#                 torch_dtype=torch.float16,  # Still use float16 to save memory
-#                 attn_implementation="default",  # Use standard attention mechanism
-#                 device_map="cuda" if torch.cuda.is_available() else "cpu"
-#             )
-#         else:
-#             raise e  # Raise other runtime errors if not related to FlashAttention
-#     return vlm
-
-# # Load the model
-# vlm = load_model()
-
-# # OCR function to extract text from an image
-# def ocr_image(image, query="Extract text from the image"):
-#     messages = [
-#         {
-#             "role": "user",
-#             "content": [
-#                 {
-#                     "type": "image",
-#                     "image": image,
-#                 },
-#                 {"type": "text", "text": query},
-#             ],
-#         }
-#     ]
-
-#     # Prepare inputs for the model
-#     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-#     image_inputs, video_inputs = process_vision_info(messages)
-#     inputs = processor(
-#         text=[text],
-#         images=image_inputs,
-#         videos=video_inputs,
-#         padding=True,
-#         return_tensors="pt",
-#     )
-#     inputs = inputs.to("cuda" if torch.cuda.is_available() else "cpu")
-
-#     # Generate the output text using the model
-#     generated_ids = vlm.generate(**inputs, max_new_tokens=512)
-#     generated_ids_trimmed = [
-#         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-#     ]
-#     output_text = processor.batch_decode(
-#         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-#     )
-#     return output_text[0]
-
-# # Gradio interface
-# def process_image(image):
-#     return ocr_image(image)
-
-# # Create Gradio interface for uploading an image
-# interface = gr.Interface(
-#     fn=process_image,
-#     inputs=gr.Image(type="pil"),
-#     outputs="text",
-#     title="Hindi & English OCR",
-#     description="Upload an image containing text in Hindi or English to extract the text using OCR."
-# )
-
-# # Launch Gradio interface in Colab
-# interface.launch(share=True, debug=True)
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from qwen_vl_utils import process_vision_info
+import torch
+import gradio as gr
+from PIL import Image
+
+processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+
+# Simplified model loading function for CPU (float32, no FlashAttention)
+def load_model():
+    return Qwen2VLForConditionalGeneration.from_pretrained(
+        "Qwen/Qwen2-VL-2B-Instruct",
+        torch_dtype=torch.float32,  # Use float32 for CPU
+        device_map="cpu"
+    )
+
+# Load the model
+vlm = load_model()
+
+# OCR function to extract text from an image
+def ocr_image(image, query="Extract text from the image"):
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                    "image": image,
+                },
+                {"type": "text", "text": query},
+            ],
+        }
+    ]
+
+    # Prepare inputs for the model
+    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    image_inputs, video_inputs = process_vision_info(messages)
+    inputs = processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
+    inputs = inputs.to("cpu")
+
+    # Generate the output text using the model
+    generated_ids = vlm.generate(**inputs, max_new_tokens=512)
+    generated_ids_trimmed = [
+        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    output_text = processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )
+    return output_text[0]
+
+# Gradio interface
+def process_image(image):
+    return ocr_image(image)
+
+# Create Gradio interface for uploading an image
+interface = gr.Interface(
+    fn=process_image,
+    inputs=gr.Image(type="pil"),
+    outputs="text",
+    title="Hindi & English OCR",
+    description="Upload an image containing text in Hindi or English to extract the text using OCR."
+)
+
+# Launch the Gradio interface
+interface.launch()
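
A quick way to sanity-check the new CPU path is to call ocr_image directly with a PIL image before relying on the web UI. A minimal sketch, assuming the code above has run and that sample.png is a placeholder for any local test image; note that CPU generation with max_new_tokens=512 can take a while for a 2B-parameter model:

from PIL import Image

img = Image.open("sample.png").convert("RGB")  # placeholder test image
print(ocr_image(img))  # prints the text Qwen2-VL extracts from the image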
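
Once the app is running, the same endpoint can also be exercised remotely with gradio_client. A sketch under stated assumptions: a recent gradio_client where image inputs are wrapped with handle_file, the Space id below is a placeholder, and gr.Interface exposes its function under the default api_name "/predict":

from gradio_client import Client, handle_file

client = Client("<user>/<space>")  # placeholder Space id
result = client.predict(handle_file("sample.png"), api_name="/predict")
print(result)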