pranshh committed
Commit d04279c
Parent: c265e46

Update app.py

Files changed (1):
  app.py  +56 -32
app.py CHANGED
@@ -1,56 +1,79 @@
-from transformers import AutoProcessor
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from qwen_vl_utils import process_vision_info
 import torch
 import gradio as gr
 from PIL import Image
-from byaldi import RAGMultiModalModel
-from qwen_vl_utils import process_vision_info
-import os
-import tempfile
-
-# Load ColPali model
-RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
 
 processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
 
+# Load the model in float32 for CPU inference (float16 would need a GPU)
 def load_model():
-    return RAG.model
+    return Qwen2VLForConditionalGeneration.from_pretrained(
+        "Qwen/Qwen2-VL-2B-Instruct",
+        torch_dtype=torch.float32,  # float32 for CPU
+        low_cpu_mem_usage=True
+    )
 
+# Load the model
 vlm = load_model()
 
-def ocr_image(image, keyword=""):
-    # Save the image to a temporary file
-    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
-        image.save(temp_file, format='PNG')
-        temp_file_path = temp_file.name
-
-    try:
-        # Index the image
-        RAG.index(input_path=temp_file_path, index_name="temp_index", overwrite=True)
-
-        # Retrieve text from the image
-        results = RAG.search("Extract all text from this image", k=1)
-
-        output_text = results[0].text if results else ''
-
-        if keyword:
-            keyword_lower = keyword.lower()
-            if keyword_lower in output_text.lower():
-                highlighted_text = output_text.replace(keyword, f"**{keyword}**")
-                return f"Keyword '{keyword}' found in the text:\n\n{highlighted_text}"
-            else:
-                return f"Keyword '{keyword}' not found in the text:\n\n{output_text}"
-        else:
-            return output_text
-    finally:
-        # Clean up the temporary file
-        os.unlink(temp_file_path)
+# OCR function to extract text from an image
+def ocr_image(image, query="Extract text from the image", keyword=""):
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                    "image": image,
+                },
+                {"type": "text", "text": query},
+            ],
+        }
+    ]
+
+    # Prepare inputs for the model
+    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    image_inputs, video_inputs = process_vision_info(messages)
+    inputs = processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
+    inputs = inputs.to("cpu")
+
+    # Generate, then strip the prompt tokens from each output sequence
+    generated_ids = vlm.generate(**inputs, max_new_tokens=512)
+    generated_ids_trimmed = [
+        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+
+    output_text = processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )[0]
+
+    if keyword:
+        keyword_lower = keyword.lower()
+        if keyword_lower in output_text.lower():
+            highlighted_text = output_text.replace(keyword, f"**{keyword}**")
+            return f"Keyword '{keyword}' found in the text:\n\n{highlighted_text}"
+        else:
+            return f"Keyword '{keyword}' not found in the text:\n\n{output_text}"
+    else:
+        return output_text
 
+# Resize oversized images before running OCR
 def process_image(image, keyword=""):
     max_size = 1024
     if max(image.size) > max_size:
         image.thumbnail((max_size, max_size))
     return ocr_image(image, keyword=keyword)
 
+# Build the Gradio interface
 interface = gr.Interface(
     fn=process_image,
     inputs=[
@@ -61,4 +84,5 @@ interface = gr.Interface(
     title="Hindi & English OCR with Keyword Search",
 )
 
+# Launch the Gradio interface
 interface.launch()
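
Two notes on the new code. First, the commit's original comment mentioned float16 with a fallback to CPU, but the committed loader is CPU-only float32. A minimal sketch of what that dtype fallback could look like (assuming the accelerate package is installed for device_map="auto"; this is not part of the commit):

    import torch
    from transformers import Qwen2VLForConditionalGeneration

    def load_model():
        # Half precision on GPU roughly halves memory use; CPU stays in float32.
        if torch.cuda.is_available():
            return Qwen2VLForConditionalGeneration.from_pretrained(
                "Qwen/Qwen2-VL-2B-Instruct",
                torch_dtype=torch.float16,
                device_map="auto",
            )
        return Qwen2VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2-VL-2B-Instruct",
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
        )

Taking this path would also mean replacing the hard-coded inputs.to("cpu") in ocr_image with a move to the model's device, e.g. inputs.to(vlm.device).

Second, the keyword check lowercases both strings, but the highlight uses the case-sensitive str.replace, so a keyword that matches only in a different case is reported as found without being bolded. A standard-library sketch of a case-insensitive highlight (the helper name is illustrative, not part of the commit):

    import re

    def highlight_keyword(text: str, keyword: str) -> str:
        # Bold every case-insensitive occurrence, preserving each match's casing.
        pattern = re.compile(re.escape(keyword), re.IGNORECASE)
        return pattern.sub(lambda m: f"**{m.group(0)}**", text)

For example, highlight_keyword("Hindi and English OCR", "hindi") returns "**Hindi** and English OCR".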