dcorcoran commited on
Commit
306d509
·
1 Parent(s): 428d984

Changed image upload technique

Browse files
Files changed (2) hide show
  1. app/main.py +5 -7
  2. app/schemas.py +10 -1
app/main.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, File, UploadFile
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import JSONResponse
4
 
@@ -9,10 +9,9 @@ import io
9
  from app.services.embedding_service import EmbeddingService
10
  from app.services.similarity_service import SimilarityService
11
  from app.services.ocr_service import OCRService
12
- from app.schemas import CardResponse
13
 
14
- import requests
15
- import os
16
 
17
 
18
  # ---------------
@@ -51,14 +50,13 @@ def health():
51
 
52
 
53
  @app.post("/predict", response_model=CardResponse)
54
- async def predict(file: UploadFile = File(...)):
55
  try:
56
- image_bytes = await file.read()
57
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
58
 
59
  embedding = app.state.embedding_service.embed(image)
60
  similar_cards = app.state.similarity_service.search(embedding, top_k=5)
61
-
62
  ocr_data = app.state.ocr_service.extract(image)
63
 
64
  if similar_cards and similar_cards[0]["score"] > 0.99:
 
1
+ from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import JSONResponse
4
 
 
9
  from app.services.embedding_service import EmbeddingService
10
  from app.services.similarity_service import SimilarityService
11
  from app.services.ocr_service import OCRService
12
+ from app.schemas import CardResponse, ImagePayload
13
 
14
+ import base64
 
15
 
16
 
17
  # ---------------
 
50
 
51
 
52
  @app.post("/predict", response_model=CardResponse)
53
+ async def predict(payload: ImagePayload):
54
  try:
55
+ image_bytes = base64.b64decode(payload.image_b64) # <-- decode base64
56
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
57
 
58
  embedding = app.state.embedding_service.embed(image)
59
  similar_cards = app.state.similarity_service.search(embedding, top_k=5)
 
60
  ocr_data = app.state.ocr_service.extract(image)
61
 
62
  if similar_cards and similar_cards[0]["score"] > 0.99:
app/schemas.py CHANGED
@@ -36,4 +36,13 @@ class CardResponse(BaseModel):
36
  hp: Optional[str] = None
37
  types: Optional[list[str]] = None
38
  moves: Optional[list[Move]] = None
39
- similar_cards: list[SimilarCard] = []
 
 
 
 
 
 
 
 
 
 
36
  hp: Optional[str] = None
37
  types: Optional[list[str]] = None
38
  moves: Optional[list[Move]] = None
39
+ similar_cards: list[SimilarCard] = []
40
+
41
+
42
+ # --------------------------------
43
+ # ----- IMAGE PAYLOAD ------------
44
+ # --------------------------------
45
+
46
+ class ImagePayload(BaseModel):
47
+ image_b64: str
48
+ filename: str