npaleti2002 committed on
Commit 70d4250 · verified · 1 Parent(s): b3c036b

Update app.py

Files changed (1)
  1. app.py +40 -64
app.py CHANGED
@@ -11,26 +11,25 @@ from transformers import (
     VisionEncoderDecoderModel,
     TrOCRProcessor,
 )
-
 from huggingface_hub import login
-import os
 
+# Optional: login via repo secret HF_TOKEN in Spaces
 hf_token = os.getenv("HF_TOKEN")
 if hf_token:
-    login(token=hf_token)
-
+    try:
+        login(token=hf_token)
+    except Exception:
+        pass
 
 TITLE = "Picture to Problem Solver"
 DESCRIPTION = (
-    "Upload an image. I’ll read the text and a math/code/science-trained AI will help answer your question."
-    "\n\n⚠️ Note: facebook/MobileLLM-R1-950M is released for non-commercial research use."
+    "Upload an image. I’ll read the text and a math/code/science-trained AI will help answer your question.\n\n"
+    "⚠️ Note: facebook/MobileLLM-R1-950M is released for non-commercial research use."
 )
 
 # ---------------------------
 # Load OCR (TrOCR)
 # ---------------------------
-# Use the "printed" variant for typed/scanned text.
-# If you expect handwriting, switch to: microsoft/trocr-base-handwritten
 OCR_MODEL_ID = os.getenv("OCR_MODEL_ID", "microsoft/trocr-base-printed")
 ocr_processor = TrOCRProcessor.from_pretrained(OCR_MODEL_ID)
 ocr_model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_ID)
@@ -41,15 +40,17 @@ ocr_model.eval()
 # ---------------------------
 LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", "facebook/MobileLLM-R1-950M")
 
-# Device & dtype selection that plays nice on Spaces
 device = "cuda" if torch.cuda.is_available() else "cpu"
-# Keep dtype conservative to avoid OOM on CPU Spaces
-torch_dtype = torch.bfloat16 if (device == "cuda" and torch.cuda.is_bf16_supported()) else torch.float32
+dtype = torch.bfloat16 if (device == "cuda" and torch.cuda.is_bf16_supported()) else torch.float32
 
 llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, use_fast=True)
+# Ensure pad token exists to prevent warnings during generation
+if llm_tokenizer.pad_token_id is None and llm_tokenizer.eos_token_id is not None:
+    llm_tokenizer.pad_token = llm_tokenizer.eos_token
+
 llm_model = AutoModelForCausalLM.from_pretrained(
     LLM_MODEL_ID,
-    torch_dtype=torch_dtype,
+    dtype=dtype,
     low_cpu_mem_usage=True,
     device_map="auto" if device == "cuda" else None,
 )
@@ -57,19 +58,16 @@ llm_model.eval()
 if device == "cpu":
     llm_model.to(device)
 
-# Ensure EOS/BOS tokens exist
 eos_token_id = llm_tokenizer.eos_token_id
 if eos_token_id is None:
-    # Fallback: add one if truly missing (rare)
     llm_tokenizer.add_special_tokens({"eos_token": "</s>"})
     llm_model.resize_token_embeddings(len(llm_tokenizer))
     eos_token_id = llm_tokenizer.eos_token_id
 
-
 SYSTEM_INSTRUCTION = (
     "You are a precise, step-by-step technical assistant. "
     "You excel at math, programming (Python, C++), and scientific reasoning. "
-    "Be concise, show steps when helpful, and avoid hallucinations. "
+    "Be concise, show steps when helpful, and avoid hallucinations."
 )
 
 USER_PROMPT_TEMPLATE = (
@@ -85,13 +83,11 @@ def build_prompt(ocr_text: str, user_question: str) -> str:
         q = f"User question: {user_question.strip()}"
     else:
         q = "Please summarize the key information and explain any math/code/science content."
-
     return f"{SYSTEM_INSTRUCTION}\n\n" + USER_PROMPT_TEMPLATE.format(
-        ocr_text=ocr_text.strip() if ocr_text else "(no text detected)",
+        ocr_text=(ocr_text or "").strip() or "(no text detected)",
         question_hint=q,
     )
 
-
 @torch.inference_mode()
 def run_pipeline(
     image: Image.Image,
@@ -100,51 +96,43 @@ def run_pipeline(
     temperature: float = 0.2,
     top_p: float = 0.9,
 ) -> Tuple[str, str]:
-    """
-    Returns:
-        (extracted_text, model_answer)
-    """
     if image is None:
         return "", "Please upload an image."
 
     # --- OCR ---
-    # TrOCR wants pixel_values prepared by its processor
-    pixel_values = ocr_processor(images=image, return_tensors="pt").pixel_values
-    with torch.inference_mode():
+    try:
+        pixel_values = ocr_processor(images=image, return_tensors="pt").pixel_values
         ocr_ids = ocr_model.generate(pixel_values, max_new_tokens=256)
-    extracted_text = ocr_processor.batch_decode(ocr_ids, skip_special_tokens=True)[0].strip()
+        extracted_text = ocr_processor.batch_decode(ocr_ids, skip_special_tokens=True)[0].strip()
+    except Exception as e:
+        return "", f"OCR failed: {e}"
 
-    # --- Build prompt for LLM ---
+    # --- Build prompt ---
     prompt = build_prompt(extracted_text, question)
 
     # --- LLM Inference ---
-    inputs = llm_tokenizer(prompt, return_tensors="pt")
-    if device == "cuda":
-        inputs = {k: v.to(llm_model.device) for k, v in inputs.items()}
-    else:
-        inputs = {k: v.to(device) for k, v in inputs.items()}
-
-    generation_kwargs = dict(
-        max_new_tokens=max_new_tokens,
-        do_sample=True if temperature > 0 else False,
-        temperature=max(0.0, min(temperature, 1.5)),
-        top_p=max(0.1, min(top_p, 1.0)),
-        eos_token_id=eos_token_id,
-        pad_token_id=llm_tokenizer.eos_token_id,  # keep decoding clean
-    )
-
-    output_ids = llm_model.generate(**inputs, **generation_kwargs)
-    # We only want the newly generated part for readability
-    gen_text = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
+    try:
+        inputs = llm_tokenizer(prompt, return_tensors="pt")
+        inputs = {k: v.to(llm_model.device if device == "cuda" else device) for k, v in inputs.items()}
+
+        generation_kwargs = dict(
+            max_new_tokens=max_new_tokens,
+            do_sample=temperature > 0,
+            temperature=max(0.0, min(temperature, 1.5)),
+            top_p=max(0.1, min(top_p, 1.0)),
+            eos_token_id=eos_token_id,
+            pad_token_id=llm_tokenizer.pad_token_id if llm_tokenizer.pad_token_id is not None else eos_token_id,
+        )
 
-    # Optional: strip the original prompt if the model echoes it
-    if gen_text.startswith(prompt):
-        gen_text = gen_text[len(prompt):].lstrip()
+        output_ids = llm_model.generate(**inputs, **generation_kwargs)
+        gen_text = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
+        if gen_text.startswith(prompt):
+            gen_text = gen_text[len(prompt):].lstrip()
+    except Exception as e:
+        gen_text = f"LLM inference failed: {e}"
 
     return extracted_text, gen_text
 
-
-
 def demo_ui():
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
         gr.Markdown(f"# {TITLE}")
@@ -155,7 +143,7 @@ def demo_ui():
         image_input = gr.Image(type="pil", label="Upload an image")
         question = gr.Textbox(
             label="Ask a question about the image (optional)",
-            placeholder="e.g., Summarize, extract key numbers, explain this formula, write Python to do X...",
+            placeholder="e.g., Summarize, extract key numbers, explain this formula, convert code to Python...",
         )
         with gr.Accordion("Generation settings (advanced)", open=False):
             max_new_tokens = gr.Slider(32, 1024, value=256, step=16, label="max_new_tokens")
@@ -174,17 +162,6 @@ def demo_ui():
             outputs=[ocr_out, llm_out],
         )
 
-        gr.Examples(
-            label="Try these sample prompts (use with your own images)",
-            examples=[
-                ["", "Summarize the document."],
-                ["", "Extract all dates and amounts, then total the amounts."],
-                ["", "Explain the equation and solve for x."],
-                ["", "Convert the pseudocode in the image to Python."],
-            ],
-            inputs=[image_input, question],
-        )
-
         gr.Markdown(
             "—\n**Licensing reminder:** facebook/MobileLLM-R1-950M is typically released for non-commercial research use. "
             "Review the model card before production use."
@@ -192,7 +169,6 @@ def demo_ui():
 
     return demo
 
-
 if __name__ == "__main__":
     demo = demo_ui()
     demo.launch()
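
A quick local check of the updated run_pipeline is possible outside Gradio. This is a minimal sketch, not part of the commit: it assumes app.py is importable from the working directory, that a test image named sample.png exists (hypothetical name), and that run_pipeline's elided parameters are question, max_new_tokens, temperature, and top_p, matching the Gradio inputs.

# Local smoke test for the updated pipeline (sketch; file and image names are hypothetical).
from PIL import Image
from app import run_pipeline  # importing app.py loads TrOCR and MobileLLM-R1 once

image = Image.open("sample.png").convert("RGB")  # any small image with readable text
extracted_text, answer = run_pipeline(
    image,
    question="Summarize the key information.",
    max_new_tokens=128,
    temperature=0.2,
    top_p=0.9,
)
print("OCR text:", extracted_text)
print("Answer:", answer)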