# main.py
"""FastAPI service that classifies uploaded fruit images with a ResNet-18.

At import time the module reads the class list from ``metadata.json``, loads
the model weights from ``models/fruit_recognition_model.pth`` and builds the
FastAPI application exposing ``/`` (HTML page) and ``/predict/`` (JSON API).
"""
import json
import logging
from io import BytesIO

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from PIL import Image
from torchvision import models

logger = logging.getLogger(__name__)


class FruitRecognizer:
    """Wrap a ResNet-18 classifier and its image preprocessing pipeline."""

    def __init__(self, model_path, num_classes):
        """Build the network, load its weights and prepare the transform.

        Args:
            model_path: Location of the saved ``state_dict`` passed to
                ``torch.load``.
            num_classes: Output size of the replacement final linear layer.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Architecture only; the weights come from the checkpoint below.
        self.model = models.resnet18(pretrained=False)
        # Swap the final layer so the output size matches the class list.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        # NOTE(review): torch.load unpickles the file, so only load
        # checkpoints from a trusted location.
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()

        # NOTE(review): mean/std look like the usual ImageNet channel
        # statistics - confirm they match what the model was trained with.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def recognize_fruit_from_path(self, image_path, class_names):
        """Classify the image stored at ``image_path``.

        The file is opened, converted to RGB and closed again before the
        image is handed to :meth:`recognize_fruit`.

        Args:
            image_path: Path (or file object) accepted by ``PIL.Image.open``.
            class_names: Indexable collection mapping a class index to its
                label.

        Returns:
            The entry of ``class_names`` selected by the model.
        """
        # The context manager releases the file handle once the pixel data
        # has been copied by convert().
        with Image.open(image_path) as source:
            img = source.convert("RGB")
        return self.recognize_fruit(img, class_names)

    def recognize_fruit(self, image, class_names):
        """Classify an already-loaded image.

        Args:
            image: Image accepted by the preprocessing transform; callers in
                this file pass an RGB ``PIL.Image``.
            class_names: Indexable collection mapping a class index to its
                label.

        Returns:
            The entry of ``class_names`` at the index of the highest score.
        """
        # unsqueeze(0) adds the batch dimension the model expects.
        batch = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(batch)
        predicted_index = outputs.argmax(dim=1).item()
        return class_names[predicted_index]


app = FastAPI()

# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; restrict allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Class labels are read once at startup; their count sizes the model head.
with open("metadata.json", "r", encoding="utf-8") as f:
    metadata = json.load(f)
class_names = metadata["classes"]

model_path = "models/fruit_recognition_model.pth"
recognizer = FruitRecognizer(model_path, len(class_names))

templates = Jinja2Templates(directory="templates")


@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Render the ``index.html`` template for the landing page."""
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/predict/")
async def predict_fruit(request: Request, file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted label as JSON.

    Returns:
        ``{"predicted_class": ...}`` on success, ``{"error": ...}`` with
        status 400 when the upload cannot be decoded as an image, and
        ``{"error": ...}`` with status 500 when inference itself fails.
    """
    logger.debug("Received prediction request for %r", file.filename)

    # Step 1: decode the upload. Failures here mean the client sent bad data.
    try:
        payload = await file.read()
        with Image.open(BytesIO(payload)) as uploaded:
            img = uploaded.convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        return JSONResponse({"error": str(exc)}, status_code=400)

    # Step 2: run the model. Failures here are server-side, so log them and
    # answer 500 instead of blaming the client with a 400.
    # NOTE(review): inference runs synchronously inside this async handler
    # and therefore occupies the event loop for its duration.
    try:
        predicted_class = recognizer.recognize_fruit(img, class_names)
    except Exception as exc:
        logger.exception("Fruit prediction failed")
        return JSONResponse({"error": str(exc)}, status_code=500)

    return JSONResponse({"predicted_class": predicted_class})


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="localhost", port=7860)