IFMedTechdemo committed on
Commit
acb34dc
·
verified ·
1 Parent(s): d6f6041

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -11,7 +11,7 @@ from transformers import (
11
  AutoModel,
12
  AutoModelForCausalLM,
13
  AutoTokenizer,
14
- Qwen3VLForConditionalGeneration,
15
  Qwen2_5_VLForConditionalGeneration,
16
  TextIteratorStreamer
17
  )
@@ -21,14 +21,14 @@ import time
21
  # Device setup
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
 
24
- # Load Chandra-OCR
25
  MODEL_ID_V = "datalab-to/chandra"
26
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
27
- model_v = Qwen3VLForConditionalGeneration.from_pretrained(
28
  MODEL_ID_V,
29
  trust_remote_code=True,
30
  torch_dtype=torch.float16,
31
- attn_implementation="sdpa" # Use PyTorch's native scaled dot product attention
32
  ).to(device).eval()
33
 
34
  # Load Nanonets-OCR2-3B
@@ -38,15 +38,15 @@ model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
38
  MODEL_ID_X,
39
  trust_remote_code=True,
40
  torch_dtype=torch.float16,
41
- attn_implementation="sdpa" # Use PyTorch's native attention
42
  ).to(device).eval()
43
 
44
- # Load Dots.OCR - REMOVE flash_attention_2 parameter
45
  MODEL_PATH_D = "strangervisionhf/dots.ocr-base-fix"
46
  processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
47
  model_d = AutoModelForCausalLM.from_pretrained(
48
  MODEL_PATH_D,
49
- attn_implementation="sdpa", # Changed from flash_attention_2
50
  torch_dtype=torch.bfloat16,
51
  device_map="auto",
52
  trust_remote_code=True
@@ -59,15 +59,15 @@ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
59
  MODEL_ID_M,
60
  trust_remote_code=True,
61
  torch_dtype=torch.bfloat16,
62
- attn_implementation="sdpa" # Use PyTorch's native attention
63
  ).to(device).eval()
64
 
65
- # Load DeepSeek-OCR - REMOVE flash_attention_2 parameter
66
  MODEL_ID_DS = "deepseek-ai/DeepSeek-OCR"
67
  tokenizer_ds = AutoTokenizer.from_pretrained(MODEL_ID_DS, trust_remote_code=True)
68
  model_ds = AutoModel.from_pretrained(
69
  MODEL_ID_DS,
70
- attn_implementation="sdpa", # Changed from flash_attention_2
71
  trust_remote_code=True,
72
  use_safetensors=True
73
  ).eval().to(device).to(torch.bfloat16)
 
11
  AutoModel,
12
  AutoModelForCausalLM,
13
  AutoTokenizer,
14
+ Qwen2VLForConditionalGeneration, # Changed from Qwen3VL
15
  Qwen2_5_VLForConditionalGeneration,
16
  TextIteratorStreamer
17
  )
 
21
  # Device setup
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
 
24
+ # Load Chandra-OCR (uses Qwen2.5-VL architecture)
25
  MODEL_ID_V = "datalab-to/chandra"
26
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
27
+ model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained( # Changed to Qwen2_5
28
  MODEL_ID_V,
29
  trust_remote_code=True,
30
  torch_dtype=torch.float16,
31
+ attn_implementation="sdpa"
32
  ).to(device).eval()
33
 
34
  # Load Nanonets-OCR2-3B
 
38
  MODEL_ID_X,
39
  trust_remote_code=True,
40
  torch_dtype=torch.float16,
41
+ attn_implementation="sdpa"
42
  ).to(device).eval()
43
 
44
+ # Load Dots.OCR
45
  MODEL_PATH_D = "strangervisionhf/dots.ocr-base-fix"
46
  processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
47
  model_d = AutoModelForCausalLM.from_pretrained(
48
  MODEL_PATH_D,
49
+ attn_implementation="sdpa",
50
  torch_dtype=torch.bfloat16,
51
  device_map="auto",
52
  trust_remote_code=True
 
59
  MODEL_ID_M,
60
  trust_remote_code=True,
61
  torch_dtype=torch.bfloat16,
62
+ attn_implementation="sdpa"
63
  ).to(device).eval()
64
 
65
+ # Load DeepSeek-OCR
66
  MODEL_ID_DS = "deepseek-ai/DeepSeek-OCR"
67
  tokenizer_ds = AutoTokenizer.from_pretrained(MODEL_ID_DS, trust_remote_code=True)
68
  model_ds = AutoModel.from_pretrained(
69
  MODEL_ID_DS,
70
+ attn_implementation="sdpa",
71
  trust_remote_code=True,
72
  use_safetensors=True
73
  ).eval().to(device).to(torch.bfloat16)