Unique00225 committed
Commit 71938a5 · verified · 1 Parent(s): bdecb08

Update app.py

Files changed (1)
  1. app.py +124 -17
app.py CHANGED
@@ -1,24 +1,131 @@
- # trocr_infer.py -- paste into your app.py and call ocr_with_trocr(pil_image)
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
  import torch

- device = "cuda" if torch.cuda.is_available() else "cpu"

- # Model choices: "microsoft/trocr-base-printed" or "microsoft/trocr-base-handwritten"
- MODEL_NAME = "microsoft/trocr-base-printed"

- processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
- model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME).to(device)
- model.eval()

- def ocr_with_trocr(pil_image):
      """
-     Input: PIL.Image (RGB)
-     Returns: recognized text string
      """
-     # Preprocess
-     pixel_values = processor(images=pil_image, return_tensors="pt").pixel_values.to(device)
-     # Generate (greedy; tune generation params if desired)
-     generated_ids = model.generate(pixel_values, max_length=128, num_beams=1)
-     text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-     return text.strip()
+ from transformers import AutoProcessor, AutoModelForVision2Seq
+ from PIL import Image
  import torch
+ import io
+ import base64
+ from fastapi import FastAPI, UploadFile, File, HTTPException
+ from fastapi.responses import JSONResponse
+ import uvicorn

+ # Initialize FastAPI app
+ app = FastAPI(title="OLM OCR API", description="OCR using allenai/olmOCR-2-7B-1025-FP8")

+ # Global variables for model and processor
+ processor = None
+ model = None
+ device = None

+ def load_model():
+     """Load the model and processor"""
+     global processor, model, device
+
+     print("Loading processor...")
+     processor = AutoProcessor.from_pretrained("allenai/olmOCR-2-7B-1025-FP8")
+
+     print("Loading model...")
+     model = AutoModelForVision2Seq.from_pretrained(
+         "allenai/olmOCR-2-7B-1025-FP8",
+         torch_dtype=torch.float16,
+         device_map="auto"
+     )
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     print(f"Model loaded on device: {device}")

+ @app.on_event("startup")
+ async def startup_event():
+     """Load model on startup"""
+     load_model()
+
+ @app.get("/")
+ async def root():
+     return {"message": "OLM OCR API is running!", "model": "allenai/olmOCR-2-7B-1025-FP8"}
+
+ @app.get("/health")
+ async def health_check():
+     return {"status": "healthy", "model_loaded": model is not None}
+
+ @app.post("/ocr")
+ async def extract_text_from_image(file: UploadFile = File(...)):
      """
+     Extract text from uploaded image
      """
+     try:
+         # Check if file is an image
+         if not file.content_type.startswith('image/'):
+             raise HTTPException(status_code=400, detail="File must be an image")
+
+         # Read image file
+         contents = await file.read()
+         image = Image.open(io.BytesIO(contents)).convert('RGB')
+
+         # Process image and generate text
+         inputs = processor(images=image, return_tensors="pt").to(device)
+
+         with torch.no_grad():
+             generated_ids = model.generate(
+                 **inputs,
+                 max_new_tokens=1024,
+                 do_sample=False,
+             )
+
+         # Decode the generated text
+         generated_text = processor.batch_decode(
+             generated_ids,
+             skip_special_tokens=True
+         )[0]
+
+         return JSONResponse({
+             "success": True,
+             "extracted_text": generated_text,
+             "filename": file.filename,
+             "file_size": len(contents)
+         })
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
+
+ @app.post("/ocr/base64")
+ async def extract_text_from_base64(data: dict):
+     """
+     Extract text from base64 encoded image
+     """
+     try:
+         if 'image' not in data:
+             raise HTTPException(status_code=400, detail="Missing 'image' field in request")
+
+         # Decode base64 image
+         image_data = base64.b64decode(data['image'])
+         image = Image.open(io.BytesIO(image_data)).convert('RGB')
+
+         # Process image and generate text
+         inputs = processor(images=image, return_tensors="pt").to(device)
+
+         with torch.no_grad():
+             generated_ids = model.generate(
+                 **inputs,
+                 max_new_tokens=1024,
+                 do_sample=False,
+             )
+
+         # Decode the generated text
+         generated_text = processor.batch_decode(
+             generated_ids,
+             skip_special_tokens=True
+         )[0]
+
+         return JSONResponse({
+             "success": True,
+             "extracted_text": generated_text
+         })
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
+
+ if __name__ == "__main__":
+     uvicorn.run(
+         "app:app",
+         host="0.0.0.0",
+         port=8000,
+         reload=True
+     )
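
A note on the new inference path: olmOCR-2 is described on its model card as a Qwen2.5-VL fine-tune, and Qwen-VL-style processors generally expect a text prompt alongside the image, so the bare processor(images=image, return_tensors="pt") call in the committed code may reject the input or generate without instructions. Below is a minimal, hedged sketch of prompt-based generation. The message structure and prompt wording are assumptions, not the model's documented interface; consult the allenai/olmOCR-2-7B-1025-FP8 model card for the prompt the model was actually trained with.

# Sketch only: prompt-based generation for a Qwen2.5-VL-style checkpoint.
# The prompt text below is an assumption, not the official olmOCR prompt.
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image
import torch

processor = AutoProcessor.from_pretrained("allenai/olmOCR-2-7B-1025-FP8")
model = AutoModelForVision2Seq.from_pretrained(
    "allenai/olmOCR-2-7B-1025-FP8",
    torch_dtype=torch.float16,
    device_map="auto",
)

image = Image.open("page.png").convert("RGB")  # hypothetical input file
messages = [{
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "Extract all text from this page."},
    ],
}]
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(images=[image], text=[prompt], return_tensors="pt").to(model.device)

with torch.no_grad():
    generated_ids = model.generate(**inputs, max_new_tokens=1024, do_sample=False)

# Decode only the newly generated tokens, not the echoed prompt
new_tokens = generated_ids[:, inputs["input_ids"].shape[1]:]
print(processor.batch_decode(new_tokens, skip_special_tokens=True)[0])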
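
For quick manual testing once the server is up, here is a minimal client sketch for both endpoints. The host, port, and file name (localhost:8000, page.png) are assumptions matching the uvicorn.run settings above.

# Sketch only: exercise the /ocr and /ocr/base64 endpoints with requests.
import base64
import requests

BASE = "http://localhost:8000"  # assumes the __main__ block above

# Multipart upload to /ocr; the form field name must be "file"
with open("page.png", "rb") as f:
    resp = requests.post(BASE + "/ocr", files={"file": ("page.png", f, "image/png")})
print(resp.json())

# JSON body with a base64-encoded image for /ocr/base64
with open("page.png", "rb") as f:
    payload = {"image": base64.b64encode(f.read()).decode("ascii")}
resp = requests.post(BASE + "/ocr/base64", json=payload)
print(resp.json())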