|
import streamlit as st |
|
import torch.nn as nn |
|
|
|
import torch |
|
from torchvision import models, transforms |
|
from PIL import Image |
|
|
|
# Class labels. The app maps the classifier's output index i to CATEGORIES[i],
# so this order is assumed to match the label order used in training
# (TODO confirm against the training script).
CATEGORIES = [
    "AIHOLE",
    "BILLESHWAR_TEMPLE",
    "CHENNAKESHWARA_TEMPLE",
    "HAMPI_CHARIOT",
    "IBRAHIM_ROZA",
    "JAIN_BASADI",
    "KAMAL_BASTI",
    "KEDARESHWARA_TEMPLE",
    "KESHAVA_TEMPLE",
    "LOTUS_MAHAL",
]

# Side length, in pixels, of the square images fed to the network.
IMG_SIZE = 224

# Streamlit re-runs this entire script on every widget interaction (e.g. each
# click of "Classify"). Building ResNet-50 and re-reading the checkpoint at
# module level therefore repeated that work on every rerun.
# st.cache_resource (Streamlit >= 1.18) keeps one shared model per process.
# On older Streamlit versions that lack it, we fall back to a no-op decorator,
# which reproduces the previous (uncached) behaviour.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_model(checkpoint_path="trained_model.pt"):
    """Build the ResNet-50 classifier and load its trained weights.

    Args:
        checkpoint_path: Path to a file holding a state_dict for this
            architecture.

    Returns:
        torch.nn.Module: the model on CPU, in eval mode, whose final layer
        produces ``len(CATEGORIES)`` logits.
    """
    # Architecture only; all weights come from the checkpoint below.
    # NOTE: ``pretrained=`` is deprecated in torchvision >= 0.13 in favour of
    # ``weights=None``. It is kept here for compatibility with older releases.
    net = models.resnet50(pretrained=False)
    # Replace the ImageNet head with one output per temple category.
    net.fc = nn.Linear(net.fc.in_features, len(CATEGORIES))
    # map_location lets a checkpoint saved from GPU tensors load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; only point it at trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    net.load_state_dict(state_dict)
    net.eval()  # put batch-norm layers into inference mode
    return net


model = _load_model()
# Input width of the classification head (the backbone's feature size).
num_features = model.fc.in_features
|
|
|
|
|
# Preprocessing pipeline applied to every image before inference.
transform = transforms.Compose(
    [
        # Stretch to an exact IMG_SIZE x IMG_SIZE square (aspect ratio is not kept).
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        # PIL image -> float tensor laid out as (C, H, W); 8-bit images land in [0, 1].
        transforms.ToTensor(),
        # Per-channel standardisation with the usual ImageNet statistics
        # (presumably the same values used at training time -- TODO confirm).
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
def classify_image(image):
    """Predict the temple category index for a PIL image.

    Args:
        image: A PIL image in any mode. It is converted to RGB first, because
            both the Normalize step and the ResNet stem expect exactly three
            channels. PNG uploads are frequently RGBA, grayscale ("L") or
            palette ("P") images, and those previously failed with a
            channel-mismatch RuntimeError. Any alpha channel is discarded.

    Returns:
        int: index into ``CATEGORIES`` of the highest-scoring class.
    """
    # convert() returns a new image, so the caller's image is not modified.
    rgb_image = image.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0)  # add a batch dimension of size 1

    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(batch)

    # Highest-scoring class along the class dimension. With a batch of one,
    # this is a single element, which .item() turns into a plain int.
    return logits.argmax(dim=1).item()
|
|
|
|
|
def main():
    """Render the Streamlit page: upload an image, preview it, classify on demand."""
    st.title("Temple Image Classification")

    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file)
        # Image.open only reads the header. Decode the pixel data now so that
        # truncated files fail here rather than later inside st.image.
        image.load()
    except OSError:
        # The uploader filters by extension only, so corrupt or mislabelled
        # files still reach PIL. PIL.UnidentifiedImageError is a subclass of
        # OSError and is therefore caught here too. Report the problem
        # instead of showing a traceback.
        st.error("Could not read the uploaded file as an image.")
        return

    st.image(image, caption="Uploaded Image", use_column_width=True)

    if st.button("Classify"):
        prediction = classify_image(image)
        st.write(f"Predicted Category: {CATEGORIES[prediction]}")


if __name__ == "__main__":
    main()
|
|