prithivMLmods committed on
Commit
2e447a3
·
verified ·
1 Parent(s): 0348aed
Files changed (1) hide show
  1. app.py +16 -4
app.py CHANGED
@@ -134,6 +134,7 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
134
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
135
 
136
 
 
137
  MODEL_ID_M = "nanonets/Nanonets-OCR2-3B"
138
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
139
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -142,7 +143,18 @@ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
142
  torch_dtype=torch.float16
143
  ).to(device).eval()
144
 
 
 
 
 
 
 
 
 
 
 
145
 
 
146
  MODEL_PATH_D = model_path_d_local
147
  processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
148
  model_d = AutoModelForCausalLM.from_pretrained(
@@ -153,6 +165,7 @@ model_d = AutoModelForCausalLM.from_pretrained(
153
  trust_remote_code=True
154
  ).eval()
155
 
 
156
  MODEL_ID_P = "strangervisionhf/paddle"
157
  processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
158
  model_p = AutoModelForCausalLM.from_pretrained(
@@ -172,6 +185,8 @@ def generate_image(model_name: str, text: str, image: Image.Image,
172
  """Generate responses for image input using the selected model."""
173
  if model_name == "Nanonets-OCR2-3B":
174
  processor, model = processor_m, model_m
 
 
175
  elif model_name == "Dots.OCR":
176
  processor, model = processor_d, model_d
177
  elif model_name == "PaddleOCR":
@@ -186,9 +201,6 @@ def generate_image(model_name: str, text: str, image: Image.Image,
186
 
187
  images = [image.convert("RGB")]
188
 
189
- # --- ERROR FIX ---
190
- # PaddleOCR's processor expects a different message format than the others.
191
- # Its chat template expects the 'content' to be a simple string, not a list.
192
  if model_name == "PaddleOCR":
193
  messages = [
194
  {"role": "user", "content": text}
@@ -254,7 +266,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
254
  formatted_output = gr.Markdown(label="Formatted Result")
255
 
256
  model_choice = gr.Radio(
257
- choices=["Nanonets-OCR2-3B", "Dots.OCR", "PaddleOCR"],
258
  label="Select Model",
259
  value="Nanonets-OCR2-3B"
260
  )
 
134
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
135
 
136
 
137
+ # Load Nanonets-OCR2-3B
138
  MODEL_ID_M = "nanonets/Nanonets-OCR2-3B"
139
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
140
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 
143
  torch_dtype=torch.float16
144
  ).to(device).eval()
145
 
146
+ # Load Nanonets-OCR2-1.5B-exp
147
+ MODEL_ID_N = "nanonets/Nanonets-OCR2-1.5B-exp"
148
+ processor_n = AutoProcessor.from_pretrained(MODEL_ID_N, trust_remote_code=True)
149
+ model_n = Qwen2_5_VLForConditionalGeneration.from_pretrained(
150
+ MODEL_ID_N,
151
+ trust_remote_code=True,
152
+ torch_dtype=torch.float16,
153
+ attn_implementation="flash_attention_2"
154
+ ).to(device).eval()
155
+
156
 
157
+ # Load Dots.OCR from the local, patched directory
158
  MODEL_PATH_D = model_path_d_local
159
  processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
160
  model_d = AutoModelForCausalLM.from_pretrained(
 
165
  trust_remote_code=True
166
  ).eval()
167
 
168
+ # Load PaddleOCR
169
  MODEL_ID_P = "strangervisionhf/paddle"
170
  processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
171
  model_p = AutoModelForCausalLM.from_pretrained(
 
185
  """Generate responses for image input using the selected model."""
186
  if model_name == "Nanonets-OCR2-3B":
187
  processor, model = processor_m, model_m
188
+ elif model_name == "Nanonets-OCR2-1.5B-exp":
189
+ processor, model = processor_n, model_n
190
  elif model_name == "Dots.OCR":
191
  processor, model = processor_d, model_d
192
  elif model_name == "PaddleOCR":
 
201
 
202
  images = [image.convert("RGB")]
203
 
 
 
 
204
  if model_name == "PaddleOCR":
205
  messages = [
206
  {"role": "user", "content": text}
 
266
  formatted_output = gr.Markdown(label="Formatted Result")
267
 
268
  model_choice = gr.Radio(
269
+ choices=["Nanonets-OCR2-3B", "Nanonets-OCR2-1.5B-exp", "Dots.OCR", "PaddleOCR"],
270
  label="Select Model",
271
  value="Nanonets-OCR2-3B"
272
  )