pranshh committed
Commit 4af6e9e · verified · 1 Parent(s): 8222a16

Update app.py

Files changed (1)
  1. app.py +18 -33
app.py CHANGED
@@ -7,64 +7,49 @@ Original file is located at
     https://colab.research.google.com/drive/1vzsQ17-W1Vy6yJ60XUwFy0QRkOR_SIg7
 """
 
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
-from qwen_vl_utils import process_vision_info
+from transformers import AutoProcessor
 import torch
 import gradio as gr
 from PIL import Image
-
+# Hypothetical imports
+from byaldi import ByaldiProcessor
+from colpali import ColPaliQwen2VLModel
 
 processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+byaldi_processor = ByaldiProcessor()
 
-# Initialize the model with float16 precision and handle fallback to CPU
-# Simplified model loading function for CPU
 def load_model():
-    return Qwen2VLForConditionalGeneration.from_pretrained(
+    return ColPaliQwen2VLModel.from_pretrained(
         "Qwen/Qwen2-VL-2B-Instruct",
-        torch_dtype=torch.float32,  # Use float32 for CPU
-        low_cpu_mem_usage=True
+        torch_dtype=torch.float32,
+        low_cpu_mem_usage=True,
+        device_map="auto"
     )
 
-# Load the model
 vlm = load_model()
 
-# OCR function to extract text from an image
 def ocr_image(image, query="Extract text from the image", keyword=""):
+    processed_image = byaldi_processor.process_image(image)
+
     messages = [
         {
             "role": "user",
             "content": [
                 {
                     "type": "image",
-                    "image": image,
+                    "image": processed_image,
                 },
                 {"type": "text", "text": query},
             ],
         }
     ]
 
-    # Prepare inputs for the model
-    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    image_inputs, video_inputs = process_vision_info(messages)
-    inputs = processor(
-        text=[text],
-        images=image_inputs,
-        videos=video_inputs,
-        padding=True,
-        return_tensors="pt",
-    )
+    inputs = processor(messages, return_tensors="pt")
    inputs = inputs.to("cpu")
 
-    # Generate the output text using the model
-    generated_ids = vlm.generate(**inputs, max_new_tokens=512)
-    generated_ids_trimmed = [
-        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
-
-    output_text = processor.batch_decode(
-        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )[0]
-
+    output = vlm.generate(**inputs, max_new_tokens=512)
+    output_text = processor.decode(output[0], skip_special_tokens=True)
+
     if keyword:
         keyword_lower = keyword.lower()
         if keyword_lower in output_text.lower():
@@ -75,14 +60,14 @@ def ocr_image(image, query="Extract text from the image", keyword=""):
     else:
         return output_text
 
-# Gradio interface
 def process_image(image, keyword=""):
+    # Resize image if it's too large
     max_size = 1024
     if max(image.size) > max_size:
         image.thumbnail((max_size, max_size))
     return ocr_image(image, keyword=keyword)
 
-# Update the Gradio interface:
+# Gradio interface:
 interface = gr.Interface(
     fn=process_image,
     inputs=[
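Note: the commit itself labels the new byaldi / colpali imports "# Hypothetical", so the revised path is a sketch rather than runnable code. For reference, the removed lines followed the documented Qwen2-VL inference recipe; a minimal, self-contained sketch of that pipeline, reassembled from the deleted lines using only the real transformers and qwen_vl_utils APIs:

# Sketch of the Qwen2-VL OCR path this commit removes; uses only the
# documented transformers / qwen_vl_utils APIs from the deleted lines.
import torch
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
vlm = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.float32,  # float32 for CPU inference, as in the old code
    low_cpu_mem_usage=True,
)

def extract_text(image: Image.Image, query: str = "Extract text from the image") -> str:
    messages = [{"role": "user", "content": [
        {"type": "image", "image": image},
        {"type": "text", "text": query},
    ]}]
    # Render the chat template and gather the vision inputs the model expects.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(text=[text], images=image_inputs, videos=video_inputs,
                       padding=True, return_tensors="pt").to("cpu")
    generated_ids = vlm.generate(**inputs, max_new_tokens=512)
    # Strip the prompt tokens so only newly generated text is decoded.
    trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
    return processor.batch_decode(trimmed, skip_special_tokens=True,
                                  clean_up_tokenization_spaces=False)[0]

By contrast, the new inputs = processor(messages, return_tensors="pt") call skips apply_chat_template and process_vision_info; with the stock AutoProcessor that preprocessing is required, so the new path presumably assumes the hypothetical wrappers handle it.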
 
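The diff view is truncated inside the gr.Interface inputs list. For illustration only, a plausible completion for a two-argument process_image(image, keyword="") app; the component labels and output type below are guesses, not the committed values:

# Illustrative completion only; the committed inputs/outputs are cut off above.
interface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Input image"),  # passed to process_image as PIL.Image
        gr.Textbox(label="Keyword (optional)"),     # forwarded to ocr_image for highlighting
    ],
    outputs=gr.Textbox(label="Extracted text"),
)

if __name__ == "__main__":
    interface.launch()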