dcorcoran commited on
Commit
6ce2877
·
1 Parent(s): e11f0c7

Added length and weight extraction

Browse files
Files changed (3) hide show
  1. app/main.py +2 -0
  2. app/schemas.py +2 -0
  3. app/services/ocr_service.py +75 -7
app/main.py CHANGED
@@ -72,6 +72,8 @@ async def predict(file: UploadFile = File(...)):
72
  hp=ocr_data.get("hp"),
73
  types=ocr_data.get("types"),
74
  moves=ocr_data.get("moves"),
 
 
75
  similar_cards=similar_cards
76
  )
77
  except Exception as e:
 
72
  hp=ocr_data.get("hp"),
73
  types=ocr_data.get("types"),
74
  moves=ocr_data.get("moves"),
75
+ length=ocr_data.get("length"),
76
+ weight=ocr_data.get("weight"),
77
  similar_cards=similar_cards
78
  )
79
  except Exception as e:
app/schemas.py CHANGED
@@ -36,4 +36,6 @@ class CardResponse(BaseModel):
36
  hp: Optional[str] = None
37
  types: Optional[list[str]] = None
38
  moves: Optional[list[Move]] = None
 
 
39
  similar_cards: list[SimilarCard] = []
 
36
  hp: Optional[str] = None
37
  types: Optional[list[str]] = None
38
  moves: Optional[list[Move]] = None
39
+ length: Optional[str] = None
40
+ weight: Optional[str] = None
41
  similar_cards: list[SimilarCard] = []
app/services/ocr_service.py CHANGED
@@ -35,18 +35,47 @@ class OCRService:
35
  region = region.point(lambda x: 0 if x < 140 else 255, "1").convert("L")
36
 
37
  return region
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  # Fuction to return all wanted card field texts
40
  def extract(self, image: Image.Image) -> dict:
41
  # Initialize width and height of image
42
  w, h = image.size
43
 
 
 
 
44
  # Name field — skip "Basic Pokemon" line at very top, just grab name row
45
- name_region = image.crop((0.05 * w, 0.06 * h, 0.72 * w, 0.13 * h))
 
 
 
 
46
 
47
  # HP field — top right, large number + "HP" text
48
  hp_region = image.crop((0.55 * w, 0.04 * h, 0.97 * w, 0.13 * h))
49
 
 
 
 
50
  # Moves field — middle to lower section
51
  moves_region = image.crop((0.02 * w, 0.52 * h, 0.98 * w, 0.88 * h))
52
 
@@ -59,8 +88,11 @@ class OCRService:
59
  "hp": self._extract_hp(hp_region),
60
  "types": self._extract_types(full_text),
61
  "moves": self._extract_moves(moves_region),
 
 
62
  }
63
 
 
64
  # Function to get pokemon name
65
  def _extract_name(self, region: Image.Image) -> str | None:
66
  # prepocess region of card
@@ -80,6 +112,7 @@ class OCRService:
80
 
81
  return text if text else None
82
 
 
83
  # Function to get pokemon hp
84
  def _extract_hp(self, region: Image.Image) -> str | None:
85
  # prepocess region of card
@@ -100,6 +133,7 @@ class OCRService:
100
 
101
  return match.group(1) if match else None
102
 
 
103
  # Function to get pokemon types
104
  def _extract_types(self, text: str) -> list[str] | None:
105
  # All pokemon types for keywrods extraction
@@ -113,6 +147,37 @@ class OCRService:
113
  found = [t for t in types if t.lower() in text.lower()]
114
  return found if found else None
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  # Function to get pokemon moves
117
  def _extract_moves(self, region: Image.Image) -> list[dict] | None:
118
  # prepocess region of card
@@ -166,18 +231,21 @@ class OCRService:
166
  from PIL import ImageDraw, ImageFont
167
 
168
  w, h = image.size
 
169
  vis = image.copy()
170
  draw = ImageDraw.Draw(vis)
171
 
172
  regions = {
173
- "Name": (0.05 * w, 0.06 * h, 0.72 * w, 0.13 * h),
174
- "HP": (0.55 * w, 0.04 * h, 0.97 * w, 0.13 * h),
175
- "Moves": (0.02 * w, 0.52 * h, 0.98 * w, 0.88 * h),
 
176
  }
177
  colors = {
178
- "Name": "red",
179
- "HP": "blue",
180
- "Moves": "green",
 
181
  }
182
 
183
  for label, box in regions.items():
 
35
  region = region.point(lambda x: 0 if x < 140 else 255, "1").convert("L")
36
 
37
  return region
38
+
39
+
40
+ # Function to check top-left corner for 'STAGE' text indicating an evolved Pokémon.
41
+ def _is_evolved(self, image: Image.Image) -> bool:
42
+ # Get size of the image
43
+ w, h = image.size
44
+
45
+ # Crop to the top left of the card
46
+ top_left = image.crop((0, 0, 0.35 * w, 0.12 * h))
47
+
48
+ # Preprocesss the cropped region
49
+ top_left = self._preprocess(top_left, scale=3)
50
+
51
+ # Uses PSM 6 (block of text mode)
52
+ text = pytesseract.image_to_string(top_left, config="--psm 6 --oem 3")
53
+
54
+ # Return boolean value depending on if the words STAGE1 or STAGE 2 appear
55
+ return bool(re.search(r'stage\s*[12]', text, re.IGNORECASE))
56
+
57
 
58
  # Fuction to return all wanted card field texts
59
  def extract(self, image: Image.Image) -> dict:
60
  # Initialize width and height of image
61
  w, h = image.size
62
 
63
+ # Detect whether this is an evolved card (Stage 1 or Stage 2)
64
+ evolved = self._is_evolved(image)
65
+
66
  # Name field — skip "Basic Pokemon" line at very top, just grab name row
67
+ # Evolved cards have an evolution picture in the top-left (~0–25% width), so the name starts further right. Basic cards start near the left edge.
68
+ if evolved:
69
+ name_region = image.crop((0.28 * w, 0.06 * h, 0.72 * w, 0.13 * h))
70
+ else:
71
+ name_region = image.crop((0.05 * w, 0.06 * h, 0.72 * w, 0.13 * h))
72
 
73
  # HP field — top right, large number + "HP" text
74
  hp_region = image.crop((0.55 * w, 0.04 * h, 0.97 * w, 0.13 * h))
75
 
76
+ # Length/Weight — bar sits just above the moves section
77
+ length_weight_region = image.crop((0.05 * w, 0.52 * h, 0.95 * w, 0.58 * h))
78
+
79
  # Moves field — middle to lower section
80
  moves_region = image.crop((0.02 * w, 0.52 * h, 0.98 * w, 0.88 * h))
81
 
 
88
  "hp": self._extract_hp(hp_region),
89
  "types": self._extract_types(full_text),
90
  "moves": self._extract_moves(moves_region),
91
+ "length": self._extract_length(length_weight_region),
92
+ "weight": self._extract_weight(length_weight_region),
93
  }
94
 
95
+
96
  # Function to get pokemon name
97
  def _extract_name(self, region: Image.Image) -> str | None:
98
  # prepocess region of card
 
112
 
113
  return text if text else None
114
 
115
+
116
  # Function to get pokemon hp
117
  def _extract_hp(self, region: Image.Image) -> str | None:
118
  # prepocess region of card
 
133
 
134
  return match.group(1) if match else None
135
 
136
+
137
  # Function to get pokemon types
138
  def _extract_types(self, text: str) -> list[str] | None:
139
  # All pokemon types for keywrods extraction
 
147
  found = [t for t in types if t.lower() in text.lower()]
148
  return found if found else None
149
 
150
+
151
+ # Extract length
152
+ def _extract_length(self, region: Image.Image) -> str | None:
153
+ # Preprocess the length region
154
+ region = self._preprocess(region, scale=3)
155
+
156
+ # Uses PSM 6 (block of text mode)
157
+ text = pytesseract.image_to_string(region, config="--psm 6 --oem 3")
158
+
159
+ # Find length match through regex patterns (after Length)
160
+ match = re.search(r"Length[:\s]+([0-9]+'\s*[0-9]+\")", text, re.IGNORECASE)
161
+
162
+ # Return a match or None
163
+ return match.group(1).strip() if match else None
164
+
165
+
166
+ # Extract weight
167
+ def _extract_weight(self, region: Image.Image) -> str | None:
168
+ # Preprocess the width region
169
+ region = self._preprocess(region, scale=3)
170
+
171
+ # Uses PSM 6 (block of text mode)
172
+ text = pytesseract.image_to_string(region, config="--psm 6 --oem 3")
173
+
174
+ # Find weight match through regex patterns (bewteen Weight and lbs)
175
+ match = re.search(r"Weight[:\s]+([\d.]+\s*lbs?\.?)", text, re.IGNORECASE)
176
+
177
+ # Return a match or None
178
+ return match.group(1).strip() if match else None
179
+
180
+
181
  # Function to get pokemon moves
182
  def _extract_moves(self, region: Image.Image) -> list[dict] | None:
183
  # prepocess region of card
 
231
  from PIL import ImageDraw, ImageFont
232
 
233
  w, h = image.size
234
+ evolved = self._is_evolved(image)
235
  vis = image.copy()
236
  draw = ImageDraw.Draw(vis)
237
 
238
  regions = {
239
+ "Name": (0.28 * w if evolved else 0.05 * w, 0.06 * h, 0.72 * w, 0.13 * h),
240
+ "HP": (0.55 * w, 0.04 * h, 0.97 * w, 0.13 * h),
241
+ "Length/Weight": (0.05 * w, 0.52 * h, 0.95 * w, 0.58 * h),
242
+ "Moves": (0.02 * w, 0.57 * h, 0.98 * w, 0.88 * h),
243
  }
244
  colors = {
245
+ "Name": "red",
246
+ "HP": "blue",
247
+ "Length/Weight": "orange",
248
+ "Moves": "green",
249
  }
250
 
251
  for label, box in regions.items():