Kortikov Mikhail
fix
a940bc4
raw
history blame
No virus
2.95 kB
# main.py
import json
import logging
from io import BytesIO

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from PIL import Image
from torchvision import models
class FruitRecognizer:
    """ResNet-18 image classifier that maps an image to a fruit class name.

    The network is created untrained and then loaded from a saved state dict;
    inference runs on CUDA when it is available, otherwise on the CPU.
    """

    def __init__(self, model_path, num_classes):
        """Build the network and load its trained weights.

        Args:
            model_path: Path to a checkpoint loadable with ``torch.load`` that
                holds the model's ``state_dict``.
            num_classes: Number of output classes; must match the checkpoint.

        Raises:
            RuntimeError: If the checkpoint does not match the architecture
                (raised by ``load_state_dict``).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Default construction is untrained on every torchvision version
        # (pretrained=False on old releases, weights=None on new ones), so the
        # deprecated ``pretrained=`` keyword is not needed.
        self.model = models.resnet18()
        # Swap the 1000-way ImageNet head for one sized to our classes.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        # NOTE(review): torch.load unpickles the file - only load trusted
        # checkpoints. On torch>=1.13 consider passing weights_only=True.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()  # inference mode for BatchNorm layers
        # Resize short side to 256, center-crop 224x224, scale to [0, 1], and
        # normalize with the usual ImageNet per-channel mean/std. Presumably
        # this matches the training preprocessing - confirm against training.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _predict_index(self, tensor):
        """Run the model on one preprocessed image tensor; return the top class index."""
        batch = tensor.unsqueeze(0).to(self.device)  # add a batch dimension of 1
        with torch.no_grad():
            logits = self.model(batch)
        # argmax returns the first maximal index on ties, same as torch.max.
        return logits.argmax(dim=1).item()

    def recognize_fruit_from_path(self, image_path, class_names):
        """Load an image from disk and classify it.

        Args:
            image_path: Filesystem path of the image.
            class_names: Sequence mapping class index -> class name.

        Returns:
            The entry of ``class_names`` for the highest-scoring class.

        Raises:
            OSError: If the file is missing or is not a readable image
                (``PIL.UnidentifiedImageError`` is an ``OSError``).
        """
        # The with-block guarantees the file handle is closed, including for
        # multi-frame formats (which Pillow keeps open after load()) and when
        # convert() raises.
        with Image.open(image_path) as src:
            img = src.convert("RGB")
        return self.recognize_fruit(img, class_names)

    def recognize_fruit(self, image, class_names):
        """Classify an in-memory image.

        Args:
            image: A PIL image in any mode (non-RGB images are converted to
                RGB), or any other input accepted by ``self.transform``.
            class_names: Sequence mapping class index -> class name.

        Returns:
            The entry of ``class_names`` for the highest-scoring class.
        """
        # PIL images expose their colour mode as a string ("RGBA", "L", ...).
        # Normalize (3 mean/std entries) and the model need 3 channels, so
        # convert exactly like recognize_fruit_from_path does.
        mode = getattr(image, "mode", None)
        if isinstance(mode, str) and mode != "RGB":
            image = image.convert("RGB")
        return class_names[self._predict_index(self.transform(image))]
app = FastAPI()
# Allow the API to be called from any origin. The endpoints in this app use no
# cookies or auth headers, so credentialed cross-origin requests are not
# needed. FastAPI's CORS docs state that "*" cannot be combined with
# allow_credentials=True; Starlette copes with that combination by reflecting
# the request's Origin, which would grant every website credentialed access.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the class list: the model's output index i maps to class_names[i], so
# the order in metadata.json must match the order used in training.
# JSON is UTF-8 by specification; decode explicitly so non-ASCII class names
# load correctly regardless of the platform's default locale encoding.
# NOTE: all paths below are relative to the current working directory.
with open("metadata.json", "r", encoding="utf-8") as f:
    metadata = json.load(f)
class_names = metadata["classes"]
model_path = "models/fruit_recognition_model.pth"
recognizer = FruitRecognizer(model_path, len(class_names))
templates = Jinja2Templates(directory="templates")
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Render the ``index.html`` page from the templates directory."""
    # Starlette's TemplateResponse expects the request inside the context.
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
_logger = logging.getLogger(__name__)


def _classify_image_bytes(data):
    """Decode raw upload bytes as an RGB image and return the predicted class name.

    Synchronous and CPU/GPU-bound (image decoding plus a model forward pass);
    the endpoint runs it in a worker thread.

    Raises:
        OSError: If ``data`` is not a recognizable image
            (``PIL.UnidentifiedImageError`` is an ``OSError``).
    """
    with Image.open(BytesIO(data)) as src:
        img = src.convert("RGB")
    return recognizer.recognize_fruit(img, class_names)


@app.post("/predict/")
async def predict_fruit(request: Request, file: UploadFile = File(...)):
    """Classify an uploaded image.

    Args:
        request: The incoming request; unused, kept for the existing signature.
        file: The uploaded image, sent as multipart form field ``file``.

    Returns:
        JSON ``{"predicted_class": <name>}`` on success, or
        ``{"error": <message>}`` with status 400 if reading, decoding or
        classifying the upload fails.
    """
    _logger.debug("predict request for upload %r", file.filename)
    try:
        data = await file.read()
        # Decoding and inference are blocking work. Calling them directly in
        # this ``async def`` would stall the event loop - and every other
        # in-flight request - for the whole forward pass, so run them in
        # FastAPI's threadpool (the same place sync ``def`` endpoints run).
        predicted_class = await run_in_threadpool(_classify_image_bytes, data)
    except Exception as e:
        # API boundary: keep the original contract (any failure -> JSON 400),
        # but log the traceback instead of silently discarding it.
        _logger.exception("Prediction failed for upload %r", file.filename)
        return JSONResponse({"error": str(e)}, status_code=400)
    return JSONResponse({"predicted_class": predicted_class})
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on localhost:7860.
    from uvicorn import run as serve

    serve(app, port=7860, host="localhost")