Unique00225 committed
Commit a2654bf · verified · 1 Parent(s): 71938a5

Update app.py

Files changed (1)
  1. app.py +33 -94
app.py CHANGED
@@ -1,66 +1,33 @@
+import gradio as gr
 from transformers import AutoProcessor, AutoModelForVision2Seq
-from PIL import Image
 import torch
-import io
-import base64
-from fastapi import FastAPI, UploadFile, File, HTTPException
-from fastapi.responses import JSONResponse
-import uvicorn
-
-# Initialize FastAPI app
-app = FastAPI(title="OLM OCR API", description="OCR using allenai/olmOCR-2-7B-1025-FP8")
-
-# Global variables for model and processor
-processor = None
-model = None
-device = None
+from PIL import Image
+import os

+# Load model directly
 def load_model():
-    """Load the model and processor"""
-    global processor, model, device
-
-    print("Loading processor...")
     processor = AutoProcessor.from_pretrained("allenai/olmOCR-2-7B-1025-FP8")
-
-    print("Loading model...")
     model = AutoModelForVision2Seq.from_pretrained(
         "allenai/olmOCR-2-7B-1025-FP8",
         torch_dtype=torch.float16,
         device_map="auto"
     )
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    print(f"Model loaded on device: {device}")
+    return processor, model

-@app.on_event("startup")
-async def startup_event():
-    """Load model on startup"""
-    load_model()
+# Load model once at startup
+processor, model = load_model()

-@app.get("/")
-async def root():
-    return {"message": "OLM OCR API is running!", "model": "allenai/olmOCR-2-7B-1025-FP8"}
-
-@app.get("/health")
-async def health_check():
-    return {"status": "healthy", "model_loaded": model is not None}
-
-@app.post("/ocr")
-async def extract_text_from_image(file: UploadFile = File(...)):
+def extract_text_from_image(image):
     """
-    Extract text from uploaded image
+    Extract text from image using OLM OCR model
     """
     try:
-        # Check if file is an image
-        if not file.content_type.startswith('image/'):
-            raise HTTPException(status_code=400, detail="File must be an image")
-
-        # Read image file
-        contents = await file.read()
-        image = Image.open(io.BytesIO(contents)).convert('RGB')
+        # Convert to RGB if needed
+        if image.mode != 'RGB':
+            image = image.convert('RGB')

         # Process image and generate text
-        inputs = processor(images=image, return_tensors="pt").to(device)
+        inputs = processor(images=image, return_tensors="pt")

         with torch.no_grad():
             generated_ids = model.generate(
@@ -75,57 +42,29 @@ async def extract_text_from_image(file: UploadFile = File(...)):
             skip_special_tokens=True
         )[0]

-        return JSONResponse({
-            "success": True,
-            "extracted_text": generated_text,
-            "filename": file.filename,
-            "file_size": len(contents)
-        })
+        return generated_text

     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
+        return f"Error processing image: {str(e)}"

-@app.post("/ocr/base64")
-async def extract_text_from_base64(data: dict):
-    """
-    Extract text from base64 encoded image
-    """
-    try:
-        if 'image' not in data:
-            raise HTTPException(status_code=400, detail="Missing 'image' field in request")
-
-        # Decode base64 image
-        image_data = base64.b64decode(data['image'])
-        image = Image.open(io.BytesIO(image_data)).convert('RGB')
-
-        # Process image and generate text
-        inputs = processor(images=image, return_tensors="pt").to(device)
-
-        with torch.no_grad():
-            generated_ids = model.generate(
-                **inputs,
-                max_new_tokens=1024,
-                do_sample=False,
-            )
-
-        # Decode the generated text
-        generated_text = processor.batch_decode(
-            generated_ids,
-            skip_special_tokens=True
-        )[0]
-
-        return JSONResponse({
-            "success": True,
-            "extracted_text": generated_text
-        })
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
+# Create Gradio interface
+demo = gr.Interface(
+    fn=extract_text_from_image,
+    inputs=gr.Image(type="pil", label="Upload Image"),
+    outputs=gr.Textbox(label="Extracted Text", lines=10),
+    title="OLM OCR Text Extraction",
+    description="Extract text from images using allenai/olmOCR-2-7B-1025-FP8 model",
+    examples=[
+        ["example1.jpg"],  # You can add example images
+        ["example2.jpg"],
+    ],
+    allow_flagging="never"
+)

+# For Hugging Face Spaces
 if __name__ == "__main__":
-    uvicorn.run(
-        "app:app",
-        host="0.0.0.0",
-        port=8000,
-        reload=True
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False
     )
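The rewrite drops the REST endpoints (/ocr and /ocr/base64), so any existing HTTP callers break; Gradio still exposes the function programmatically through the gradio_client package. A minimal client-side sketch, where the Space id "Unique00225/olm-ocr" is a placeholder assumption, not taken from the commit (substitute the real Space, or http://localhost:7860 when running app.py locally). Note also that recent Gradio releases deprecate allow_flagging in favor of flagging_mode, so the flag set above may warn on current versions.

from gradio_client import Client, handle_file

# "Unique00225/olm-ocr" is a hypothetical Space id, not taken from the commit;
# a local URL such as "http://localhost:7860" also works.
client = Client("Unique00225/olm-ocr")

# gr.Interface registers its endpoint under api_name="/predict" by default.
result = client.predict(
    handle_file("page.png"),  # local path or URL of the image to OCR
    api_name="/predict",
)
print(result)  # the extracted text returned by extract_text_from_image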
 
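One caveat on the committed inference path: olmOCR-2-7B is a prompted vision-language checkpoint (Qwen2.5-VL lineage), and calling processor(images=image, ...) with no text prompt may yield empty or meaningless generations. A hedged sketch of the chat-template call that transformers vision-language processors generally expect; the prompt wording is illustrative only, not the model's official OCR prompt:

# Sketch assuming `processor`, `model`, and a PIL `image` exist as in app.py;
# the instruction text below is an illustrative placeholder, not the official
# olmOCR prompt.
messages = [{
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "Transcribe all of the text in this image."},
    ],
}]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=[prompt], images=[image], return_tensors="pt").to(model.device)

with torch.no_grad():
    generated_ids = model.generate(**inputs, max_new_tokens=1024, do_sample=False)

# Drop the prompt tokens so only the newly generated answer is decoded.
trimmed = generated_ids[:, inputs["input_ids"].shape[1]:]
extracted_text = processor.batch_decode(trimmed, skip_special_tokens=True)[0]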