Unique00225 committed on
Commit
f4d5db9
·
verified ·
1 Parent(s): e49c2db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -45
app.py CHANGED
@@ -2,69 +2,70 @@ import gradio as gr
2
  from transformers import AutoProcessor, AutoModelForVision2Seq
3
  import torch
4
  from PIL import Image
5
- import os
6
 
7
- # Load model directly
 
 
 
 
8
  def load_model():
9
- processor = AutoProcessor.from_pretrained("allenai/olmOCR-2-7B-1025-FP8")
10
- model = AutoModelForVision2Seq.from_pretrained(
11
- "allenai/olmOCR-2-7B-1025-FP8",
12
- torch_dtype=torch.float16,
13
- device_map="auto"
14
- )
15
- return processor, model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- # Load model once at startup
18
  processor, model = load_model()
19
 
20
  def extract_text_from_image(image):
21
- """
22
- Extract text from image using OLM OCR model
23
- """
24
  try:
25
- # Convert to RGB if needed
26
- if image.mode != 'RGB':
27
- image = image.convert('RGB')
28
 
29
- # Process image and generate text
30
- inputs = processor(images=image, return_tensors="pt")
 
31
 
 
32
  with torch.no_grad():
33
- generated_ids = model.generate(
34
  **inputs,
35
- max_new_tokens=1024,
36
  do_sample=False,
 
37
  )
38
 
39
- # Decode the generated text
40
- generated_text = processor.batch_decode(
41
- generated_ids,
42
- skip_special_tokens=True
43
- )[0]
44
-
45
- return generated_text
46
 
47
  except Exception as e:
48
- return f"Error processing image: {str(e)}"
49
 
50
- # Create Gradio interface
51
  demo = gr.Interface(
52
- fn=extract_text_from_image,
53
- inputs=gr.Image(type="pil", label="Upload Image"),
54
- outputs=gr.Textbox(label="Extracted Text", lines=10),
55
- title="OLM OCR Text Extraction",
56
- description="Extract text from images using allenai/olmOCR-2-7B-1025-FP8 model",
57
- examples=[
58
- ["example1.jpg"], # You can add example images
59
- ["example2.jpg"],
60
- ],
61
- allow_flagging="never"
62
  )
63
 
64
- # For Hugging Face Spaces
65
  if __name__ == "__main__":
66
- demo.launch(
67
- server_name="0.0.0.0",
68
- server_port=7860,
69
- share=False
70
- )
 
2
  from transformers import AutoProcessor, AutoModelForVision2Seq
3
  import torch
4
  from PIL import Image
 
5
 
6
+ # Use the GPU when CUDA is available, otherwise fall back to CPU
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+ torch_dtype = torch.float16 if device == "cuda" else torch.float32
9
+
10
+ @gr.cache_resource
11
  def load_model():
12
+ try:
13
+ print("Loading OLM OCR model...")
14
+
15
+ # Load with optimizations for limited resources
16
+ processor = AutoProcessor.from_pretrained("allenai/olmOCR-2-7B-1025-FP8")
17
+ model = AutoModelForVision2Seq.from_pretrained(
18
+ "allenai/olmOCR-2-7B-1025-FP8",
19
+ torch_dtype=torch_dtype,
20
+ device_map="auto" if device == "cuda" else None,
21
+ low_cpu_mem_usage=True
22
+ )
23
+
24
+ if device == "cpu":
25
+ model = model.to(device)
26
+
27
+ print("Model loaded successfully!")
28
+ return processor, model
29
+
30
+ except Exception as e:
31
+ print(f"Error loading model: {e}")
32
+ return None, None
33
 
 
34
  processor, model = load_model()
35
 
36
  def extract_text_from_image(image):
37
+ if processor is None or model is None:
38
+ return "Model failed to load. The model might be too large for this environment."
39
+
40
  try:
41
+ if image is None:
42
+ return "Please upload an image first."
 
43
 
44
+ # Convert and process image
45
+ image = image.convert('RGB')
46
+ inputs = processor(images=image, return_tensors="pt").to(device)
47
 
48
+ # Generate with optimizations
49
  with torch.no_grad():
50
+ outputs = model.generate(
51
  **inputs,
52
+ max_new_tokens=256, # Reduced for faster processing
53
  do_sample=False,
54
+ num_beams=1 # Faster but less accurate
55
  )
56
 
57
+ text = processor.decode(outputs[0], skip_special_tokens=True)
58
+ return text
 
 
 
 
 
59
 
60
  except Exception as e:
61
+ return f"Error: {str(e)}"
62
 
 
63
  demo = gr.Interface(
64
+ extract_text_from_image,
65
+ gr.Image(type="pil"),
66
+ gr.Textbox(lines=5),
67
+ title="OLM OCR"
 
 
 
 
 
 
68
  )
69
 
 
70
  if __name__ == "__main__":
71
+ demo.launch()