# Source captured from a hosted "Space" page (apparently Hugging Face Spaces) whose status read "Runtime error".
from io import BytesIO
from typing import Any, Type

import pydantic
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from pydantic import BaseModel
from torchvision import transforms
class TensorSchema(pydantic.BaseModel):
    """JSON-serializable summary of a tensor's metadata (not its values).

    Pydantic derives validation and the JSON schema from the field
    annotations below, so no hand-written schema hook is needed.
    """

    # NOTE: a previous hand-written ``__get_pydantic_core_schema__`` was
    # removed. It had a signature pydantic v2 never calls (the hook is a
    # classmethod taking ``(source_type, handler)``). Its return annotation
    # ``pydantic.schema.Schema`` is not a pydantic v2 API and is evaluated
    # when the class body runs. Defining that hook on a BaseModel would also
    # replace the model's own generated schema.

    # Size of each dimension, e.g. [1, 3, 224, 224].
    shape: list[int]
    # Text form of the dtype, e.g. "torch.float32".
    dtype: str
    # Whether autograd records operations on the tensor.
    requires_grad: bool

    @classmethod
    def from_tensor(cls, tensor: torch.Tensor) -> "TensorSchema":
        """Build a ``TensorSchema`` describing ``tensor``'s shape, dtype and grad flag."""
        return cls(
            shape=list(tensor.shape),
            dtype=str(tensor.dtype),
            requires_grad=tensor.requires_grad,
        )
class TorchTensor(torch.Tensor):
    """Annotation type that lets pydantic v2 models hold ``torch.Tensor`` values.

    Validation accepts any ``torch.Tensor`` unchanged (no copy, no conversion).
    Serialization (``model_dump`` / JSON) emits only the tensor's metadata:
    ``{"shape": [...], "dtype": "...", "requires_grad": bool}``. It never
    emits the tensor's values.
    """

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type: Any, handler: Any) -> Any:
        """Return the pydantic v2 core schema for this type.

        Pydantic calls this hook as a classmethod with ``(source_type, handler)``.
        """
        # pydantic_core is the compiled core that pydantic v2 is built on; it
        # is imported lazily so merely defining this class does not need it.
        from pydantic_core import core_schema

        def validate(value: Any) -> torch.Tensor:
            if not isinstance(value, torch.Tensor):
                # ValueError (not TypeError) so pydantic wraps it in a ValidationError.
                raise ValueError(f"expected a torch.Tensor, got {type(value).__name__}")
            return value

        def serialize(value: torch.Tensor) -> dict[str, Any]:
            return {
                "shape": list(value.shape),
                "dtype": str(value.dtype),
                "requires_grad": value.requires_grad,
            }

        return core_schema.no_info_plain_validator_function(
            validate,
            serialization=core_schema.plain_serializer_function_ser_schema(serialize),
        )

    @classmethod
    def __get_pydantic_json_schema__(cls, _core_schema: Any, handler: Any) -> dict[str, Any]:
        """Describe the serialized metadata object in JSON Schema (e.g. for OpenAPI docs).

        Returned directly because pydantic cannot derive a JSON schema from a
        plain validator function.
        """
        return {
            "type": "object",
            "properties": {
                "shape": {"type": "array", "items": {"type": "integer"}},
                "dtype": {"type": "string"},
                "requires_grad": {"type": "boolean"},
            },
            "required": ["shape", "dtype", "requires_grad"],
        }
class Prediction(BaseModel):
    """Response body of the prediction endpoint.

    ``prediction`` holds predicted class indices as plain ints, so FastAPI can
    encode the response as JSON.
    """

    # Declared as list[int] rather than torch.Tensor: pydantic has no schema
    # for arbitrary classes such as torch.Tensor (class creation fails without
    # arbitrary_types_allowed), and a tensor is not JSON-encodable anyway.
    prediction: list[int]

    @pydantic.field_validator("prediction", mode="before")
    @classmethod
    def coerce_tensor(cls, value: Any) -> Any:
        """Accept a ``torch.Tensor`` and convert it to a flat list of Python numbers."""
        if isinstance(value, torch.Tensor):
            # reshape(-1) so a 0-d tensor (a single index) still yields a list.
            return value.reshape(-1).tolist()
        return value
app = FastAPI()

# Load the trained model once at import time so every request reuses it.
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only host.
# NOTE(review): the file name looks like a PyTorch Lightning ModelCheckpoint
# output. If so, torch.load returns a checkpoint dict (with a "state_dict"
# entry) rather than a model, and the model class must be rebuilt and
# populated via load_state_dict. Confirm how this file was produced.
# NOTE(review): recent torch releases default torch.load to weights_only=True,
# which refuses to unpickle a whole nn.Module. Pass weights_only=False only
# if this checkpoint is trusted.
model = torch.load("best_model-epoch=01-val_loss=3.00.ckpt", map_location="cpu")
if not isinstance(model, nn.Module):
    # Fail at startup with a clear message instead of on every request with
    # "'<type>' object is not callable".
    raise TypeError(
        "Expected the checkpoint to contain a torch.nn.Module, got "
        f"{type(model).__name__}; rebuild the model and load its state_dict instead."
    )
# Evaluation mode: disables dropout and makes batch-norm use running statistics.
model.eval()
def preprocess_input(input):
    """Turn a PIL image into a normalized tensor for the image classifier.

    Args:
        input: A PIL Image object. Any mode is accepted; it is converted to
            RGB so the result always has three channels.

    Returns:
        A float tensor of shape (3, 224, 224). Pixel values are scaled to
        [0, 1] and then normalized per channel with the standard ImageNet
        mean/std.
    """
    # Force three channels: grayscale ("L"), palette ("P") or RGBA uploads
    # would otherwise give 1 or 4 channels and break the 3-channel Normalize.
    image = input.convert("RGB")
    # Resize to the spatial size the model is fed.
    image = image.resize((224, 224))
    # ToTensor turns the HWC uint8 image (0-255) into a CHW float tensor in
    # [0, 1]. Normalize needs channels-first input, and the mean/std below
    # are defined on the [0, 1] scale.
    tensor = transforms.ToTensor()(image)
    # ImageNet per-channel statistics. Assumes the model was trained with the
    # same normalization; confirm against the training pipeline.
    return transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(tensor)
@app.post("/predict")
async def predict_endpoint(input: UploadFile = File(...)) -> Prediction:
    """Classify an uploaded image and return the top predicted class.

    Args:
        input: The uploaded image, sent as the multipart form field "input"
            (FastAPI uses the parameter name as the field name).

    Returns:
        A ``Prediction`` whose ``prediction`` list holds the arg-max index
        over dim 1 of the model output for the single uploaded image.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    # Pillow is already a hard dependency of torchvision (imported by this
    # module). The original code used Image.open without importing it.
    from PIL import Image

    data = await input.read()
    try:
        # convert() forces a full decode, so corrupt or non-image uploads fail
        # here rather than deeper in preprocessing.
        image = Image.open(BytesIO(data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image.") from exc

    # Add a batch dimension of 1 in front of the (C, H, W) image tensor.
    batch = preprocess_input(image).unsqueeze(0)
    # no_grad: inference only, so skip building the autograd graph.
    # NOTE: this synchronous model call runs inside an async handler and
    # blocks the event loop for its duration.
    with torch.no_grad():
        output = model(batch)
    # Assumes output is (batch, num_classes) scores. TODO confirm for this model.
    predicted_class = output.argmax(1)
    # Plain ints so the response is JSON-serializable.
    return Prediction(prediction=predicted_class.tolist())
if __name__ == "__main__":
    # Executing this file directly serves the app with uvicorn on every
    # network interface, port 8000.
    import uvicorn

    uvicorn.run(app=app, port=8000, host="0.0.0.0")