Cloud_AI_model / app.py
Tom
fixed a path issue
e39c55a
import os

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from PIL import Image
from torchvision import transforms
# Define the VisionTransformer model class
class VisionTransformer(nn.Module):
    """ViT-B/16 backbone from torchvision with a classification head of size ``num_classes``.

    The backbone is created with ``weights=None`` (random initialisation); trained
    parameters are meant to be restored afterwards with ``load_state_dict``.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.vit_b_16(weights=None)  # architecture only, no pretrained weights
        # Swap the stock head for one matching this task's number of classes.
        head_inputs = backbone.heads.head.in_features
        backbone.heads.head = nn.Linear(head_inputs, num_classes)
        self.model = backbone

    def forward(self, X):
        """Return the raw logits produced by the wrapped ViT for batch ``X``."""
        return self.model(X)
# Function to load the model
def load_model(model_path, num_classes):
    """Create a VisionTransformer and restore its parameters from ``model_path``.

    Args:
        model_path: Path of a saved ``state_dict`` for ``VisionTransformer``.
        num_classes: Size of the classification head; must match the checkpoint.

    Returns:
        The model with weights loaded, switched to evaluation mode.
    """
    net = VisionTransformer(num_classes=num_classes)
    # Tensors are mapped onto the CPU so no GPU is required at inference time.
    # NOTE(review): torch.load unpickles the file; if the checkpoint could come from
    # an untrusted source, consider weights_only=True (supported in newer torch).
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    net.load_state_dict(checkpoint)
    return net.eval()  # eval() returns the module itself
# Preprocess the input image
def preprocess_image(image_path):
    """Load an image from disk and convert it into a model-ready batch tensor.

    The image is converted to RGB, its shorter side is resized to 224, and it is
    center-cropped to 224x224. It is then converted to a tensor and normalised
    per channel, which yields a tensor of shape ``(1, 3, 224, 224)``.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A float tensor with a leading batch dimension of 1.
    """
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Widely used ImageNet channel mean/std; presumably these match the
        # training pipeline -- confirm if the model was trained differently.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Open inside a context manager so the file handle is released deterministically.
    # convert() returns an independent, fully loaded copy, so it stays valid after
    # the source image is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    return transform(image).unsqueeze(0)  # Add batch dimension
# Dictionary mapping short names to full cloud names
# Short dataset labels -> human-readable cloud type names (insertion order preserved).
cloud_name_mapping = dict(
    AC='Altocumulus',
    As='Altostratus',
    Cb='Cumulonimbus',
    Cc='Cirrocumulus',
    Ci='Cirrus',
    Cs='Cirrostratus',
    Ct='Contrails',
    Cu='Cumulus',
    Ns='Nimbostratus',
    Sc='Stratocumulus',
    St='Stratus',
)
# Function to make a prediction
def predict(image_path, model, class_names):
    """Classify a single image and report the full cloud name with its confidence.

    Args:
        image_path: Image location passed through to ``preprocess_image``.
        model: Module that maps the preprocessed batch to per-class logits.
        class_names: Short labels indexed by the model's output position.

    Returns:
        ``(full_name, confidence)``. ``full_name`` is the ``cloud_name_mapping`` entry
        for the winning label, or ``"Unknown"`` if the label is not in it.
        ``confidence`` is that class's softmax probability as a Python float.
    """
    batch = preprocess_image(image_path)
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)
        # The winner is chosen from the raw logits; softmax only supplies the score.
        best = torch.max(logits, 1).indices.item()
    score = probs[0][best].item()
    short_label = class_names[best]
    return cloud_name_mapping.get(short_label, "Unknown"), score
# Load the model
# Checkpoint file name. Replace with your model path.
_MODEL_FILENAME = "VisionTransformer_with_crop_final_model.pth"
# Resolve the checkpoint relative to the current working directory first. If it is
# not there, fall back to the directory containing this script, so the app also
# starts when launched from another directory. `__file__` is undefined in some
# contexts (e.g. interactive sessions), hence the guard.
model_path = _MODEL_FILENAME
if not os.path.exists(model_path) and "__file__" in globals():
    model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), _MODEL_FILENAME)
# Position i must be the label of model output i (predict indexes this list with the
# argmax); presumably the order used at training time -- confirm.
class_names = ['AC', 'As', 'Cb', 'Cc', 'Ci', 'Cs', 'Ct', 'Cu', 'Ns', 'Sc', 'St']  # Replace with your actual class names
model = load_model(model_path, num_classes=len(class_names))
# Gradio interface
def classify_image(image_path):
    """Gradio callback: classify the uploaded image and format the result as text.

    Args:
        image_path: Path to the uploaded file (the input component uses
            ``type="filepath"``), or ``None`` when the user submits without
            uploading an image.

    Returns:
        ``"Prediction: <name>\\nConfidence: <p>"`` with ``p`` to two decimals, or a
        prompt asking for an image when none was provided.
    """
    # Gradio passes None for an empty image input; without this guard the call
    # would fail inside preprocessing and surface as a generic error in the UI.
    if image_path is None:
        return "Please upload an image."
    predicted_class, confidence = predict(image_path, model, class_names)
    return f"Prediction: {predicted_class}\nConfidence: {confidence:.2f}"
# Web UI: an image upload (delivered to the callback as a file path) in,
# plain-text prediction out. Keep a handle to the Interface before launching it.
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath"),
    outputs="text",
    title="Vision Transformer Image Classification",
)
demo.launch()