geoNet / app.py
RolfeD11's picture
Update app.py
73ff029 verified
import streamlit as st
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import transformers
from transformers import ViTModel, ViTFeatureExtractor, ViTImageProcessor, ViTForImageClassification
from PIL import Image
import matplotlib.pyplot as plt
class geoNet(nn.Module):
    """MLP head that maps a 768-wide feature vector to a region and coordinates.

    Four Linear -> BatchNorm1d -> ReLU stages shrink the features
    768 -> 512 -> 256 -> 128 -> 64. Two independent linear heads then read the
    64-wide representation:
      * ``classifier``: 63 unnormalized class scores (one per province/state);
      * ``regressor``: 2 values, intended as (latitude, longitude).

    The attribute names (fc1..fc4, bn1..bn4, classifier, regressor) are the
    keys of the state_dict loaded into this module, so they must not change.
    """

    def __init__(self):
        super().__init__()
        self.name = "geo"
        # Hidden trunk. Each stage is registered Linear first, then its
        # BatchNorm, to keep the same module (and state_dict) order.
        self.fc1, self.bn1 = nn.Linear(768, 512), nn.BatchNorm1d(512)
        self.fc2, self.bn2 = nn.Linear(512, 256), nn.BatchNorm1d(256)
        self.fc3, self.bn3 = nn.Linear(256, 128), nn.BatchNorm1d(128)
        self.fc4, self.bn4 = nn.Linear(128, 64), nn.BatchNorm1d(64)
        # Output heads sharing the 64-wide trunk output.
        self.classifier = nn.Linear(64, 63)  # province/state scores
        self.regressor = nn.Linear(64, 2)  # (latitude, longitude)

    def forward(self, x):
        """Run both heads on a batch of feature vectors.

        Args:
            x: tensor whose first axis is the batch; all remaining axes are
               flattened and must total 768 values per sample. In training
               mode BatchNorm1d needs more than one sample per batch.

        Returns:
            (province_scores, coords): shapes (batch, 63) and (batch, 2).
        """
        hidden = x.view(x.size(0), -1)
        stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
            (self.fc4, self.bn4),
        )
        for linear, norm in stages:
            hidden = F.relu(norm(linear(hidden)))
        return self.classifier(hidden), self.regressor(hidden)
# Model artifacts. Streamlit re-executes this whole script on every widget
# interaction, so loading goes through st.cache_resource: the weights are read
# from disk once per server process instead of on every rerun.
model_path = 'geo_53.07_15.52'
model_name = 'google/vit-base-patch16-224-in21k'


@st.cache_resource
def _load_geonet(path):
    """Build a geoNet, load the CPU-mapped state_dict at ``path``, set eval mode."""
    net = geoNet()
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state_dict needs, and refuses arbitrary pickled objects.
    state = torch.load(path, map_location=torch.device('cpu'), weights_only=True)
    net.load_state_dict(state)
    net.eval()
    return net


@st.cache_resource
def _load_vit(name):
    """Load the pretrained ViT backbone (eval mode) and its image processor."""
    backbone = ViTModel.from_pretrained(name, attn_implementation="eager")
    backbone.eval()
    return backbone, ViTImageProcessor.from_pretrained(name)


model = _load_geonet(model_path)
# Pretrained ViT used for feature extraction, plus its matching preprocessor.
vit, processor = _load_vit(model_name)
# Define preprocessing function
def crop_bottom(img):
width, height = img.size
return img.crop((0, 0, width, height - 18)) # get rid of author label
# Tensor pipeline used by load_image:
#   1. strip the author label from the bottom of the photo;
#   2. resize to 224x224;
#   3. convert to a float (C, H, W) tensor in [0, 1];
#   4. map each channel to [-1, 1] via (x - 0.5) / 0.5.
_preprocess_steps = [
    transforms.Lambda(crop_bottom),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
]
preprocess = transforms.Compose(_preprocess_steps)
def load_image(image_path):
    """Open an image as RGB and return it as a batch of one preprocessed tensor.

    Args:
        image_path: anything ``PIL.Image.open`` accepts (a filename or a
            binary file object).

    Returns:
        ``preprocess`` applied to the RGB image, with a leading batch axis of
        size 1 added.
    """
    with Image.open(image_path) as source:
        rgb = source.convert('RGB')  # convert() loads the pixels into a new image
    return preprocess(rgb)[None]  # [None] == unsqueeze(0): add the batch axis
def imshow(img):
    """Render a normalized (C, H, W) image tensor on the Streamlit page.

    Assumes ``img`` was normalized with mean 0.5 / std 0.5 per channel (as the
    ``preprocess`` pipeline does), so values lie in roughly [-1, 1]. They are
    mapped back to [0, 1] before plotting.
    """
    fig, ax = plt.subplots()
    # Undo Normalize(mean=0.5, std=0.5); detach/cpu so GPU or grad-tracking
    # tensors can be converted to NumPy as well.
    pixels = (img.detach().cpu() / 2 + 0.5).numpy()
    # Matplotlib expects (H, W, C). Clip float round-off slightly outside
    # [0, 1], which imshow would otherwise clip with a warning.
    ax.imshow(np.clip(np.transpose(pixels, (1, 2, 0)), 0.0, 1.0))
    ax.axis('off')  # hide ticks and axis labels
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; without this each
    # Streamlit rerun would leak one figure.
    plt.close(fig)
# Streamlit app
import io

# Index -> name lookup for the 63 classifier outputs. Presumably these indices
# must match the label encoding used when the checkpoint was trained - verify
# before editing.
class_to_idx = {
    'Alabama': 0, 'Alaska': 1, 'Alberta': 2, 'Arizona': 3, 'Arkansas': 4,
    'British Columbia': 5, 'California': 6, 'Colorado': 7, 'Connecticut': 8,
    'Delaware': 9, 'Florida': 10, 'Georgia': 11, 'Hawaii': 12, 'Idaho': 13,
    'Illinois': 14, 'Indiana': 15, 'Iowa': 16, 'Kansas': 17, 'Kentucky': 18,
    'Louisiana': 19, 'Maine': 20, 'Manitoba': 21, 'Maryland': 22,
    'Massachusetts': 23, 'Michigan': 24, 'Minnesota': 25, 'Mississippi': 26,
    'Missouri': 27, 'Montana': 28, 'Nebraska': 29, 'Nevada': 30,
    'New Brunswick': 31, 'New Hampshire': 32, 'New Jersey': 33,
    'New Mexico': 34, 'New York': 35, 'Newfoundland and Labrador': 36,
    'North Carolina': 37, 'North Dakota': 38, 'Northwest Territories': 39,
    'Nova Scotia': 40, 'Nunavut': 41, 'Ohio': 42, 'Oklahoma': 43,
    'Ontario': 44, 'Oregon': 45, 'Pennsylvania': 46,
    'Prince Edward Island': 47, 'Quebec': 48, 'Rhode Island': 49,
    'Saskatchewan': 50, 'South Carolina': 51, 'South Dakota': 52,
    'Tennessee': 53, 'Texas': 54, 'Utah': 55, 'Vermont': 56, 'Virginia': 57,
    'Washington': 58, 'West Virginia': 59, 'Wisconsin': 60, 'Wyoming': 61,
    'Yukon': 62,
}
idx_to_class = {idx: class_name for class_name, idx in class_to_idx.items()}

st.title("Image Geolocation Prediction")

uploaded_file = st.file_uploader("Choose an image...", type="jpg")
if uploaded_file is not None:
    # Work on the upload in memory. Writing every upload to one fixed file on
    # disk would let concurrent sessions on the same server overwrite each
    # other's image between the write and the read.
    raw_bytes = uploaded_file.getvalue()

    # Show the upload, then the cropped/resized tensor view from load_image.
    image = load_image(io.BytesIO(raw_bytes))
    st.image(io.BytesIO(raw_bytes), caption='Uploaded Image', use_column_width=True)
    image_to_show = image.squeeze(0)
    imshow(image_to_show)

    # Extract the ViT [CLS] embedding and run the geoNet heads. Both models
    # are loaded on the CPU, so inputs stay there. no_grad avoids building an
    # autograd graph for the (large) ViT forward pass.
    # NOTE(review): the ViT receives the full, uncropped image, whereas the
    # displayed tensor has the author-label strip removed by crop_bottom.
    # Confirm this matches the preprocessing used at training time.
    pil_image = Image.open(io.BytesIO(raw_bytes)).convert('RGB')
    inputs = processor(images=pil_image, return_tensors="pt")
    with torch.no_grad():
        outputs = vit(**inputs)
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]
        province_preds, coord_preds = model(cls_hidden_state)
    _, predicted_province = torch.max(province_preds, 1)
    predicted_coords = coord_preds.cpu().numpy()
    # One image was submitted, so row 0 holds its (latitude, longitude) pair.
    lat, lon = predicted_coords[0]

    # Display predictions with increased font size.
    st.markdown(
        f"<h3 style='font-size:20px;'>Predicted Province/State: {idx_to_class.get(predicted_province.item(), None)}</h3>",
        unsafe_allow_html=True,
    )
    st.markdown(
        f"<h3 style='font-size:20px;'>Predicted Coordinates: {lat:.4f}, {lon:.4f}</h3>",
        unsafe_allow_html=True,
    )
else:
    st.write("Please upload an image to get predictions.")