Anwarkh1's picture
Update main.py
9adb9c8 verified
raw
history blame
No virus
1.65 kB
import os
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache" # Set cache directory to a writable location
from fastapi import FastAPI, UploadFile, File
from transformers import ViTForImageClassification, ViTFeatureExtractor, AutoTokenizer
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import io
# FastAPI application object; route decorators in this module register on it.
app = FastAPI()

# Load the pretrained ViT image classifier by its Hugging Face model id.
model_name = "Anwarkh1/Skin_Cancer-Image_Classification"
model = ViTForImageClassification.from_pretrained(model_name)
# NOTE(review): a tokenizer is loaded for an image-classification model and is
# not used by the visible code -- confirm whether this line is needed (the
# imported ViTFeatureExtractor may have been intended instead).
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Human-readable class names, looked up by the model's predicted class index.
# NOTE(review): the order presumably mirrors the model's id2label mapping --
# confirm against the model config.
class_labels = [
    'benign_keratosis-like_lesions',
    'basal_cell_carcinoma',
    'actinic_keratoses',
    'vascular_lesions',
    'melanocytic_Nevi',
    'melanoma',
    'dermatofibroma',
]

# Preprocessing applied to every uploaded image: resize to 224x224, then
# convert the PIL image to a float tensor.
# NOTE(review): no mean/std normalization is applied here -- confirm this
# matches the preprocessing the model was trained with.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
# Define API endpoint for model inference
@app.post('/predict')
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the top predicted class.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        A dict with two keys:
        ``predicted_class`` -- the entry of ``class_labels`` with the highest
        softmax score, and ``accuracy`` -- that softmax score as a float.
        The score is the model's confidence for this one image, not a
        measured accuracy; the key name is kept for API compatibility.

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded as
            an image (empty, truncated or non-image payload).
    """
    # Imported locally so this handler is self-contained and does not rely
    # on changes to the module-level import block.
    from fastapi import HTTPException

    contents = await file.read()
    try:
        # Force three channels: grayscale, palette and RGBA uploads would
        # otherwise produce 1- or 4-channel tensors instead of the RGB input
        # the model is fed for ordinary colour images.
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except OSError as err:
        # PIL signals undecodable data with OSError (UnidentifiedImageError
        # is a subclass); report it as a client error instead of a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image",
        ) from err

    pixel_values = transform(image).unsqueeze(0)  # Add batch dimension

    # Inference only: gradients are not needed.
    with torch.no_grad():
        outputs = model(pixel_values)

    # Calculate softmax probabilities over the class dimension.
    probabilities = torch.softmax(outputs.logits, dim=1)

    # Get predicted class index (batch size is 1) and its probability.
    predicted_idx = torch.argmax(probabilities, dim=1).item()
    predicted_label = class_labels[predicted_idx]
    predicted_accuracy = probabilities[0][predicted_idx].item()

    return {'predicted_class': predicted_label, 'accuracy': predicted_accuracy}