dcorcoran committed on
Commit
f965aa3
·
1 Parent(s): bc27ad8

Added OCR region drawing tab to visualize what the OCR is looking at

Browse files
Files changed (2) hide show
  1. app/main.py +37 -2
  2. app/services/ocr_service.py +82 -9
app/main.py CHANGED
@@ -1,6 +1,6 @@
1
  from fastapi import FastAPI, File, UploadFile
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from fastapi.responses import JSONResponse
4
 
5
  from contextlib import asynccontextmanager
6
  from PIL import Image
@@ -35,6 +35,7 @@ app.add_middleware(
35
  # ----- ROUTES -------
36
  # --------------------
37
 
 
38
  @app.on_event("startup")
39
  async def startup_event():
40
  """Load models and indexes at startup."""
@@ -44,11 +45,13 @@ async def startup_event():
44
  print("Models and indexes loaded successfully.")
45
 
46
 
 
47
  @app.get("/health")
48
  def health():
49
  return {"status": "ok"}
50
 
51
 
 
52
  @app.post("/predict", response_model=CardResponse)
53
  async def predict(file: UploadFile = File(...)):
54
  try:
@@ -74,6 +77,8 @@ async def predict(file: UploadFile = File(...)):
74
  except Exception as e:
75
  return JSONResponse(status_code=500, content={"error": str(e)})
76
 
 
 
77
  @app.get("/cards")
78
  def get_cards(limit: int = 100):
79
  try:
@@ -81,8 +86,38 @@ def get_cards(limit: int = 100):
81
  return cards
82
  except Exception as e:
83
  return JSONResponse(status_code=500, content={"error": str(e)})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  import uvicorn
86
 
87
  if __name__ == "__main__":
88
- uvicorn.run("app.main:app", host="0.0.0.0", port=7860)
 
 
1
  from fastapi import FastAPI, File, UploadFile
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import JSONResponse, StreamingResponse
4
 
5
  from contextlib import asynccontextmanager
6
  from PIL import Image
 
35
  # ----- ROUTES -------
36
  # --------------------
37
 
38
+ # Actions done on app startup
39
  @app.on_event("startup")
40
  async def startup_event():
41
  """Load models and indexes at startup."""
 
45
  print("Models and indexes loaded successfully.")
46
 
47
 
48
+ # Endpoint to signify API health
49
  @app.get("/health")
50
  def health():
51
  return {"status": "ok"}
52
 
53
 
54
+ # Endpoint to extract information from pokemon cards
55
  @app.post("/predict", response_model=CardResponse)
56
  async def predict(file: UploadFile = File(...)):
57
  try:
 
77
  except Exception as e:
78
  return JSONResponse(status_code=500, content={"error": str(e)})
79
 
80
+
81
+ # Endpoint to display a sample of cards in the database
82
  @app.get("/cards")
83
  def get_cards(limit: int = 100):
84
  try:
 
86
  return cards
87
  except Exception as e:
88
  return JSONResponse(status_code=500, content={"error": str(e)})
89
+
90
+
91
+ # Endpoint to display the original card and its OCR cropped boxes
92
+ @app.post("/visualize")
93
+ async def visualize(file: UploadFile = File(...)):
94
+ try:
95
+ # Read the raw bytes from the uploaded file
96
+ image_bytes = await file.read()
97
+
98
+ # Decode bytes into a Pillow Image and normalize to RGB color space
99
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
100
+
101
+ # Delegate to OCRService to draw colored bounding boxes over each crop region
102
+ annotated = app.state.ocr_service.visualize_regions(image)
103
+
104
+ # Write the annotated image into an in-memory buffer as PNG
105
+ buf = io.BytesIO()
106
+ annotated.save(buf, format="PNG")
107
+
108
+ # Reset buffer position to the start before streaming
109
+ buf.seek(0)
110
+
111
+ # Stream the PNG bytes directly back to the client
112
+ return StreamingResponse(buf, media_type="image/png")
113
+
114
+ except Exception as e:
115
+ return JSONResponse(status_code=500, content={"error": str(e)})
116
+
117
+
118
 
119
  import uvicorn
120
 
121
  if __name__ == "__main__":
122
+ uvicorn.run("app.main:app", host="0.0.0.0", port=7860)
123
+
app/services/ocr_service.py CHANGED
@@ -9,39 +9,51 @@ import os
9
  class OCRService:
10
 
11
  def __init__(self):
 
12
  if sys.platform.startswith("win"):
13
  pytesseract.pytesseract.tesseract_cmd = os.getenv("TESSERACT_PATH", "C:/Program Files/Tesseract-OCR/tesseract.exe")
 
 
14
  else:
15
  pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"
16
 
 
17
  def _preprocess(self, region: Image.Image, scale: int = 3) -> Image.Image:
18
- """Upscale, convert to grayscale, and threshold for better OCR."""
19
  region = region.resize(
20
  (region.width * scale, region.height * scale),
21
  Image.LANCZOS
22
  )
23
- region = region.convert("L") # grayscale
24
- # Increase contrast
 
 
 
25
  region = ImageEnhance.Contrast(region).enhance(2.0)
 
26
  # Threshold to black/white
27
  region = region.point(lambda x: 0 if x < 140 else 255, "1").convert("L")
 
28
  return region
29
 
 
30
  def extract(self, image: Image.Image) -> dict:
 
31
  w, h = image.size
32
 
33
- # Name — skip "Basic Pokemon" line at very top, just grab name row
34
  name_region = image.crop((0.05 * w, 0.06 * h, 0.72 * w, 0.13 * h))
35
 
36
- # HP — top right, large number + "HP" text
37
  hp_region = image.crop((0.55 * w, 0.04 * h, 0.97 * w, 0.13 * h))
38
 
39
- # Moves — middle to lower section
40
  moves_region = image.crop((0.02 * w, 0.52 * h, 0.98 * w, 0.88 * h))
41
 
42
  # Full image for type detection
43
  full_text = pytesseract.image_to_string(image)
44
 
 
45
  return {
46
  "name": self._extract_name(name_region),
47
  "hp": self._extract_hp(hp_region),
@@ -49,44 +61,76 @@ class OCRService:
49
  "moves": self._extract_moves(moves_region),
50
  }
51
 
 
52
  def _extract_name(self, region: Image.Image) -> str | None:
 
53
  region = self._preprocess(region, scale=3)
 
 
54
  text = pytesseract.image_to_string(region, config="--psm 7 --oem 3").strip()
 
55
  # Clean up noise — keep only lines that look like a name
56
  lines = [l.strip() for l in text.splitlines() if l.strip()]
 
 
57
  for line in lines:
58
- # Skip lines that are clearly not a name
59
  if re.search(r'[A-Z][a-z]+', line) and len(line) < 30:
60
  return line
 
61
  return text if text else None
62
 
 
63
  def _extract_hp(self, region: Image.Image) -> str | None:
 
64
  region = self._preprocess(region, scale=3)
 
 
65
  text = pytesseract.image_to_string(region, config="--psm 6 --oem 3")
 
66
  # Look for a number near "HP"
 
67
  match = re.search(r'(\d+)\s*HP|HP\s*(\d+)', text, re.IGNORECASE)
 
68
  if match:
69
  return match.group(1) or match.group(2)
70
- # Fallback: just grab any standalone number (the HP value)
 
71
  match = re.search(r'\b(\d{2,3})\b', text)
 
72
  return match.group(1) if match else None
73
 
 
74
  def _extract_types(self, text: str) -> list[str] | None:
 
75
  types = [
76
  "Fire", "Water", "Grass", "Electric", "Psychic",
77
  "Fighting", "Darkness", "Metal", "Colorless",
78
  "Dragon", "Fairy", "Lightning", "Normal"
79
  ]
 
 
80
  found = [t for t in types if t.lower() in text.lower()]
81
  return found if found else None
82
 
 
83
  def _extract_moves(self, region: Image.Image) -> list[dict] | None:
 
84
  region = self._preprocess(region, scale=2)
 
 
85
  text = pytesseract.image_to_string(region, config="--psm 6 --oem 3")
 
 
86
  lines = [line.strip() for line in text.splitlines() if line.strip()]
87
 
 
88
  moves = []
 
 
89
  i = 0
 
 
90
  while i < len(lines):
91
  # Match: "MoveName 10" or "MoveName 10+" or "MoveName" alone on a line (0 damage moves)
92
  match = re.match(r'^([A-Z][a-zA-Z\s]{2,25}?)\s{2,}(\d+\+?)$', lines[i])
@@ -94,6 +138,7 @@ class OCRService:
94
  # Try looser match for lines like "Psychic 10+"
95
  match = re.match(r'^([A-Z][a-zA-Z]+)\s+(\d+\+?)$', lines[i])
96
 
 
97
  if match:
98
  # Collect any following lines as move description until next move or end
99
  desc_lines = []
@@ -108,7 +153,35 @@ class OCRService:
108
  "text": " ".join(desc_lines) if desc_lines else None
109
  })
110
  i = j
 
 
111
  else:
112
  i += 1
113
 
114
- return moves if moves else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  class OCRService:
10
 
11
  def __init__(self):
12
+ # On Windows, read the Tesseract path from the TESSERACT_PATH environment variable (falling back to the default install location)
13
  if sys.platform.startswith("win"):
14
  pytesseract.pytesseract.tesseract_cmd = os.getenv("TESSERACT_PATH", "C:/Program Files/Tesseract-OCR/tesseract.exe")
15
+
16
+ # On other platforms, use the Tesseract binary at /usr/bin/tesseract
17
  else:
18
  pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"
19
 
20
+ # Function to preprocess image to desired format
21
  def _preprocess(self, region: Image.Image, scale: int = 3) -> Image.Image:
22
+ # Increase the size of the image so that OCR performs better
23
  region = region.resize(
24
  (region.width * scale, region.height * scale),
25
  Image.LANCZOS
26
  )
27
+
28
+ # Convert image to grayscale
29
+ region = region.convert("L")
30
+
31
+ # Increase contrast of image
32
  region = ImageEnhance.Contrast(region).enhance(2.0)
33
+
34
  # Threshold to black/white
35
  region = region.point(lambda x: 0 if x < 140 else 255, "1").convert("L")
36
+
37
  return region
38
 
39
+ # Function to return the text of all wanted card fields
40
  def extract(self, image: Image.Image) -> dict:
41
+ # Initialize width and height of image
42
  w, h = image.size
43
 
44
+ # Name field — skip "Basic Pokemon" line at very top, just grab name row
45
  name_region = image.crop((0.05 * w, 0.06 * h, 0.72 * w, 0.13 * h))
46
 
47
+ # HP field — top right, large number + "HP" text
48
  hp_region = image.crop((0.55 * w, 0.04 * h, 0.97 * w, 0.13 * h))
49
 
50
+ # Moves field — middle to lower section
51
  moves_region = image.crop((0.02 * w, 0.52 * h, 0.98 * w, 0.88 * h))
52
 
53
  # Full image for type detection
54
  full_text = pytesseract.image_to_string(image)
55
 
56
+ # Return wanted fields using the field-specific extraction functions
57
  return {
58
  "name": self._extract_name(name_region),
59
  "hp": self._extract_hp(hp_region),
 
61
  "moves": self._extract_moves(moves_region),
62
  }
63
 
64
+ # Function to get pokemon name
65
  def _extract_name(self, region: Image.Image) -> str | None:
66
+ # Preprocess the cropped region of the card
67
  region = self._preprocess(region, scale=3)
68
+
69
+ # Uses PSM 7 to get text (single line mode)
70
  text = pytesseract.image_to_string(region, config="--psm 7 --oem 3").strip()
71
+
72
  # Clean up noise — keep only lines that look like a name
73
  lines = [l.strip() for l in text.splitlines() if l.strip()]
74
+
75
+ # Loop through all detected lines
76
  for line in lines:
77
+ # Skip lines that are clearly not a name - return the first line that contains a capital letter followed by lowercase letters and is under 30 characters
78
  if re.search(r'[A-Z][a-z]+', line) and len(line) < 30:
79
  return line
80
+
81
  return text if text else None
82
 
83
+ # Function to get pokemon hp
84
  def _extract_hp(self, region: Image.Image) -> str | None:
85
+ # Preprocess the cropped region of the card
86
  region = self._preprocess(region, scale=3)
87
+
88
+ # Uses PSM 6 (block of text mode)
89
  text = pytesseract.image_to_string(region, config="--psm 6 --oem 3")
90
+
91
  # Look for a number near "HP"
92
+ # First looks for a number adjacent to the letters "HP" in either order
93
  match = re.search(r'(\d+)\s*HP|HP\s*(\d+)', text, re.IGNORECASE)
94
+
95
  if match:
96
  return match.group(1) or match.group(2)
97
+
98
+ # Fallback: grabbing any standalone 2–3 digit number if the "HP" label wasn't recognized
99
  match = re.search(r'\b(\d{2,3})\b', text)
100
+
101
  return match.group(1) if match else None
102
 
103
+ # Function to get pokemon types
104
  def _extract_types(self, text: str) -> list[str] | None:
105
+ # All Pokémon types used for keyword extraction
106
  types = [
107
  "Fire", "Water", "Grass", "Electric", "Psychic",
108
  "Fighting", "Darkness", "Metal", "Colorless",
109
  "Dragon", "Fairy", "Lightning", "Normal"
110
  ]
111
+
112
+ # Checks whether any known Pokémon type name appears anywhere in the OCR output
113
  found = [t for t in types if t.lower() in text.lower()]
114
  return found if found else None
115
 
116
+ # Function to get pokemon moves
117
  def _extract_moves(self, region: Image.Image) -> list[dict] | None:
118
+ # Preprocess the cropped region of the card
119
  region = self._preprocess(region, scale=2)
120
+
121
+ # Uses PSM 6 (block of text mode)
122
  text = pytesseract.image_to_string(region, config="--psm 6 --oem 3")
123
+
124
+ # Collect all detected non-empty lines
125
  lines = [line.strip() for line in text.splitlines() if line.strip()]
126
 
127
+ # Empty list to store detected moves
128
  moves = []
129
+
130
+ # Index starting at 0
131
  i = 0
132
+
133
+ # Loop through all detected lines
134
  while i < len(lines):
135
  # Match: "MoveName 10" or "MoveName 10+" or "MoveName" alone on a line (0 damage moves)
136
  match = re.match(r'^([A-Z][a-zA-Z\s]{2,25}?)\s{2,}(\d+\+?)$', lines[i])
 
138
  # Try looser match for lines like "Psychic 10+"
139
  match = re.match(r'^([A-Z][a-zA-Z]+)\s+(\d+\+?)$', lines[i])
140
 
141
+ # Separate the move into name, damage, and text if a move is detected
142
  if match:
143
  # Collect any following lines as move description until next move or end
144
  desc_lines = []
 
153
  "text": " ".join(desc_lines) if desc_lines else None
154
  })
155
  i = j
156
+
157
+ # Otherwise advance the index to inspect the next line
158
  else:
159
  i += 1
160
 
161
+ return moves if moves else None
162
+
163
+
164
+ def visualize_regions(self, image: Image.Image) -> Image.Image:
165
+ """Draw colored boxes over each OCR crop region and return the annotated image."""
166
+ from PIL import ImageDraw, ImageFont
167
+
168
+ w, h = image.size
169
+ vis = image.copy()
170
+ draw = ImageDraw.Draw(vis)
171
+
172
+ regions = {
173
+ "Name": (0.05 * w, 0.06 * h, 0.72 * w, 0.13 * h),
174
+ "HP": (0.55 * w, 0.04 * h, 0.97 * w, 0.13 * h),
175
+ "Moves": (0.02 * w, 0.52 * h, 0.98 * w, 0.88 * h),
176
+ }
177
+ colors = {
178
+ "Name": "red",
179
+ "HP": "blue",
180
+ "Moves": "green",
181
+ }
182
+
183
+ for label, box in regions.items():
184
+ draw.rectangle(box, outline=colors[label], width=3)
185
+ draw.text((box[0] + 4, box[1] + 4), label, fill=colors[label])
186
+
187
+ return vis