Changed image upload technique
Browse files- app/main.py +5 -7
- app/schemas.py +10 -1
app/main.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from fastapi import FastAPI
|
| 2 |
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
from fastapi.responses import JSONResponse
|
| 4 |
|
|
@@ -9,10 +9,9 @@ import io
|
|
| 9 |
from app.services.embedding_service import EmbeddingService
|
| 10 |
from app.services.similarity_service import SimilarityService
|
| 11 |
from app.services.ocr_service import OCRService
|
| 12 |
-
from app.schemas import CardResponse
|
| 13 |
|
| 14 |
-
import
|
| 15 |
-
import os
|
| 16 |
|
| 17 |
|
| 18 |
# ---------------
|
|
@@ -51,14 +50,13 @@ def health():
|
|
| 51 |
|
| 52 |
|
| 53 |
@app.post("/predict", response_model=CardResponse)
|
| 54 |
-
async def predict(
|
| 55 |
try:
|
| 56 |
-
image_bytes =
|
| 57 |
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 58 |
|
| 59 |
embedding = app.state.embedding_service.embed(image)
|
| 60 |
similar_cards = app.state.similarity_service.search(embedding, top_k=5)
|
| 61 |
-
|
| 62 |
ocr_data = app.state.ocr_service.extract(image)
|
| 63 |
|
| 64 |
if similar_cards and similar_cards[0]["score"] > 0.99:
|
|
|
|
| 1 |
+
from fastapi import FastAPI
|
| 2 |
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
from fastapi.responses import JSONResponse
|
| 4 |
|
|
|
|
| 9 |
from app.services.embedding_service import EmbeddingService
|
| 10 |
from app.services.similarity_service import SimilarityService
|
| 11 |
from app.services.ocr_service import OCRService
|
| 12 |
+
from app.schemas import CardResponse, ImagePayload
|
| 13 |
|
| 14 |
+
import base64
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
# ---------------
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
@app.post("/predict", response_model=CardResponse)
|
| 53 |
+
async def predict(payload: ImagePayload):
|
| 54 |
try:
|
| 55 |
+
image_bytes = base64.b64decode(payload.image_b64) # <-- decode base64
|
| 56 |
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 57 |
|
| 58 |
embedding = app.state.embedding_service.embed(image)
|
| 59 |
similar_cards = app.state.similarity_service.search(embedding, top_k=5)
|
|
|
|
| 60 |
ocr_data = app.state.ocr_service.extract(image)
|
| 61 |
|
| 62 |
if similar_cards and similar_cards[0]["score"] > 0.99:
|
app/schemas.py
CHANGED
|
@@ -36,4 +36,13 @@ class CardResponse(BaseModel):
|
|
| 36 |
hp: Optional[str] = None
|
| 37 |
types: Optional[list[str]] = None
|
| 38 |
moves: Optional[list[Move]] = None
|
| 39 |
-
similar_cards: list[SimilarCard] = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
hp: Optional[str] = None
|
| 37 |
types: Optional[list[str]] = None
|
| 38 |
moves: Optional[list[Move]] = None
|
| 39 |
+
similar_cards: list[SimilarCard] = []
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# --------------------------------
|
| 43 |
+
# ----- IMAGE PAYLOAD ------------
|
| 44 |
+
# --------------------------------
|
| 45 |
+
|
| 46 |
+
class ImagePayload(BaseModel):
|
| 47 |
+
image_b64: str
|
| 48 |
+
filename: str
|