jzhang533 committed
Commit f073f41 · 1 Parent(s): 1d14db4

Signed-off-by: Zhang Jun <jzhang533@gmail.com>

Files changed (2)
  1. app.py +60 -77
  2. requirements.txt +1 -1
app.py CHANGED
@@ -17,23 +17,23 @@ processor = None
 
 def load_model():
     global model, processor
-    try:
-        print("Loading model...")
-        model = AutoModelForCausalLM.from_pretrained(
-            MODEL_PATH,
-            trust_remote_code=True,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-        ).eval()
-
-        processor = AutoProcessor.from_pretrained(
-            MODEL_PATH, trust_remote_code=True, use_fast=True
-        )
+    print("Loading model...")
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_PATH,
+        trust_remote_code=True,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+    ).eval()
+
+    processor = AutoProcessor.from_pretrained(
+        MODEL_PATH, trust_remote_code=True, use_fast=True
+    )
+
+    # Set pad_token_id to avoid warning during generation
+    if model.generation_config.pad_token_id is None:
+        model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
 
-        print("Model loaded successfully!")
-    except Exception as e:
-        print(f"Error loading model: {e}")
-        raise e
+    print("Model loaded successfully!")
 
 
 # Load model on startup
@@ -53,70 +53,53 @@ def perform_ocr(image):
     if image is None:
         return "Please upload an image first."
 
-    try:
-        # Ensure model is on GPU
-        if model.device.type == "cpu" and torch.cuda.is_available():
-            print("Moving model to GPU...")
-            model.to("cuda")
-
-        # Convert to PIL Image if needed
-        if not isinstance(image, Image.Image):
-            image = Image.fromarray(image)
-
-        # Ensure RGB format
-        image = image.convert("RGB")
-
-        # Prepare the prompt
-        messages = [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "image", "image": image},
-                    {"type": "text", "text": "OCR:"},
-                ],
-            }
-        ]
-
-        # Process inputs
-        text = processor.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
+    # Ensure model is on GPU
+    if model.device.type == "cpu" and torch.cuda.is_available():
+        print("Moving model to GPU...")
+        model.to("cuda")
+
+    # Convert to PIL Image if needed
+    if not isinstance(image, Image.Image):
+        image = Image.fromarray(image)
+
+    # Ensure RGB format
+    image = image.convert("RGB")
+
+    # Prepare the prompt
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": "OCR:"},
+            ],
+        }
+    ]
+
+    # Process inputs
+    text = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    inputs = processor(text=[text], images=[image], return_tensors="pt")
+    inputs = {
+        k: (v.to(model.device) if isinstance(v, torch.Tensor) else v)
+        for k, v in inputs.items()
+    }
+
+    # Generate text
+    with torch.inference_mode():
+        generated = model.generate(
+            **inputs,
+            max_new_tokens=2048,
+            do_sample=False,
+            use_cache=True,
         )
-        inputs = processor(text=[text], images=[image], return_tensors="pt")
-
-        # Generate text
-        with torch.inference_mode():
-            device = next(model.parameters()).device
-            inputs = inputs.to(device)
-
-            # Extract input_ids and other tensors to avoid keyword argument issues
-            input_ids_tensor = inputs.input_ids if hasattr(inputs, 'input_ids') else inputs.get('input_ids')
-            pixel_values = inputs.pixel_values if hasattr(inputs, 'pixel_values') else inputs.get('pixel_values')
-            attention_mask = inputs.attention_mask if hasattr(inputs, 'attention_mask') else inputs.get('attention_mask')
-
-            generated_ids = model.generate(
-                input_ids=input_ids_tensor,
-                pixel_values=pixel_values,
-                attention_mask=attention_mask,
-                max_new_tokens=2048,
-                do_sample=False,
-                use_cache=True,
-            )
-
-        if "input_ids" in inputs:
-            input_ids = inputs.input_ids
-        else:
-            print("inputs: # fallback", inputs)
-            input_ids = inputs.inputs
-
-        generated_ids_trimmed = [
-            out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)
-        ]
-        answer = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
-        return answer
+    input_length = inputs["input_ids"].shape[1]
+    generated_tokens = generated[:, input_length:]
+    answer = processor.batch_decode(generated_tokens, skip_special_tokens=True)[0]
 
-    except Exception as e:
-        return f"Error during OCR: {e!s}"
+    return answer
 
 
 # Create Gradio interface
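
For reference, a minimal standalone sketch of the inference path this commit moves to, run outside the Gradio UI. The MODEL_PATH value and the image file below are placeholders, not taken from this repo; the real constants live at the top of app.py.

    # Standalone sketch (not part of the commit): same load + generate + trim
    # flow as the new app.py.  MODEL_PATH and "sample.png" are placeholders.
    import torch
    from PIL import Image
    from transformers import AutoModelForCausalLM, AutoProcessor

    MODEL_PATH = "org/ocr-vlm-checkpoint"  # placeholder; see app.py for the real value

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"
    ).eval()
    processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True, use_fast=True)

    image = Image.open("sample.png").convert("RGB")  # placeholder input image
    messages = [{"role": "user", "content": [
        {"type": "image", "image": image},
        {"type": "text", "text": "OCR:"},
    ]}]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], return_tensors="pt")
    # Move every tensor to the model's device; non-tensor entries pass through unchanged.
    inputs = {k: (v.to(model.device) if isinstance(v, torch.Tensor) else v) for k, v in inputs.items()}

    with torch.inference_mode():
        generated = model.generate(**inputs, max_new_tokens=2048, do_sample=False)

    # generate() returns prompt + completion, so slice off the prompt length before decoding.
    new_tokens = generated[:, inputs["input_ids"].shape[1]:]
    print(processor.batch_decode(new_tokens, skip_special_tokens=True)[0])

Passing **inputs straight to generate() forwards whatever keys the processor emits without enumerating them by hand, which is what the removed hasattr juggling was approximating.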
requirements.txt CHANGED
@@ -1,5 +1,5 @@
 torch>=2.0.0
-transformers
+transformers==4.57.1
 accelerate
 pillow>=10.0.0
 einops
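
Since the Space now depends on an exact transformers release, a small startup guard (a sketch, not present in the repo) can make environment drift fail loudly; "4.57.1" here simply mirrors the pin above.

    # Optional sanity check (sketch): fail fast if the installed transformers
    # does not match the version pinned in requirements.txt.
    import transformers

    EXPECTED = "4.57.1"  # keep in sync with requirements.txt
    if transformers.__version__ != EXPECTED:
        raise RuntimeError(
            f"Expected transformers=={EXPECTED}, found {transformers.__version__}"
        )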