Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
import torch.nn as nn | |
import torchvision.transforms as transforms | |
from PIL import Image | |
from resnet import Resnet50Flower102 | |
import pandas as pd | |
st.title("Flower Image Classification")

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Project-local ResNet-50 classifier (resnet.py). It receives `device`; the
# inference code feeds it tensors already on `device`, so presumably it places
# its parameters there itself -- TODO confirm in resnet.py.
model = Resnet50Flower102(device)

# Lookup table whose "Name" column is indexed by the predicted class index.
flowers_data = pd.read_csv("flowerdata.csv")

uploaded_file = st.file_uploader("Choose your file", type=["jpg", "png", "jpeg"])

# Load trained weights. map_location remaps tensors saved on another device
# (e.g. CUDA) onto the active one, so a GPU-trained checkpoint also loads on a
# CPU-only host.
# NOTE(review): Streamlit re-executes the whole script on every widget
# interaction, so the CSV and the weights are re-read each time; consider
# st.cache_resource if the deployed Streamlit version provides it.
model.load_state_dict(torch.load("model.pth", map_location=torch.device(device)))
# Evaluation-time preprocessing. It applies three steps:
#   1. resize to a fixed 224x224 (aspect ratio is not preserved);
#   2. convert the PIL image to a float tensor with values in [0, 1];
#   3. normalise each channel with the standard ImageNet mean/std.
# Normalize has three mean/std entries, so the input image must be 3-channel
# RGB. Presumably this mirrors the validation transform used in training --
# confirm.
transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
if uploaded_file is not None:
    # Classify the uploaded image and show it with the predicted flower name.
    image = Image.open(uploaded_file)

    # The model path needs exactly 3 channels: PNGs can be RGBA or palette
    # images and some JPEGs are grayscale, which would break the 3-channel
    # Normalize in transform_val. The untouched `image` is kept for display.
    rgb_image = image.convert("RGB")

    # ToTensor already yields float32. Move the tensor to the model's device
    # and add a leading batch dimension of size 1.
    img = transform_val(rgb_image).to(device).unsqueeze(0)

    # eval() disables training-only behaviour (e.g. dropout, batch-norm
    # statistic updates) in any such layers.
    model.eval()
    with torch.no_grad():
        scores = model(img)

    # Highest-scoring class for the single image in the batch, as a plain int.
    class_idx = int(scores.argmax(dim=1)[0].item())
    # Assumes row i of flowerdata.csv names class i -- confirm against training.
    flower_name = flowers_data["Name"].iloc[class_idx]

    st.header("Input Image")
    st.image(image=image, use_column_width=True)
    st.write("##", flower_name)