gauravtripathy commited on
Commit
34733d6
·
1 Parent(s): 3df8c85

Upload 2 files

Browse files

Historical places detection for ten class of temples

Files changed (2) hide show
  1. ancientdetection.py +52 -0
  2. trained_model.pt +3 -0
ancientdetection.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch.nn as nn
3
+
4
+ import torch
5
+ from torchvision import models, transforms
6
+ from PIL import Image
7
+
8
+ CATEGORIES = ["AIHOLE", "BILLESHWAR_TEMPLE", "CHENNAKESHWARA_TEMPLE", "HAMPI_CHARIOT", "IBRAHIM_ROZA", "JAIN_BASADI", "KAMAL_BASTI", "KEDARESHWARA_TEMPLE", "KESHAVA_TEMPLE", "LOTUS_MAHAL"]
9
+ IMG_SIZE = 224
10
+ # Load the trained model
11
+ model = models.resnet50(pretrained=False)
12
+ num_features = model.fc.in_features
13
+ model.fc = nn.Linear(num_features, len(CATEGORIES))
14
+ model.load_state_dict(torch.load("trained_model.pt", map_location=torch.device('cpu')))
15
+ model.eval()
16
+
17
+ # Define the image transform
18
+ transform = transforms.Compose([
19
+ transforms.Resize((IMG_SIZE, IMG_SIZE)),
20
+ transforms.ToTensor(),
21
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
22
+ ])
23
+
24
+ # Define the prediction function
25
+ def classify_image(image):
26
+ image = transform(image).unsqueeze(0)
27
+
28
+ # Make prediction
29
+ with torch.no_grad():
30
+ outputs = model(image)
31
+ _, predicted = torch.max(outputs.data, 1)
32
+
33
+ return predicted.item()
34
+
35
+ # Streamlit app
36
+ def main():
37
+ st.title("Temple Image Classification")
38
+
39
+ # File uploader
40
+ uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
41
+
42
+ if uploaded_file is not None:
43
+ image = Image.open(uploaded_file)
44
+ st.image(image, caption="Uploaded Image", use_column_width=True)
45
+
46
+ # Classify image on button click
47
+ if st.button("Classify"):
48
+ prediction = classify_image(image)
49
+ st.write(f"Predicted Category: {CATEGORIES[prediction]}")
50
+
51
+ if __name__ == "__main__":
52
+ main()
trained_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efa64c02a8e1261be97bd81c1567faba65d94c7a30c61083e92b77b42e0df3c5
3
+ size 94433165