from model import AQC_NET
import torch
import torchvision.transforms as T
from PIL import Image
import gradio as gr
import requests
import os
def get_file(url, path, chunk_size=128):
    """Stream a file from `url` to the local `path` in fixed-size chunks."""
    r = requests.get(url, stream=True)
    r.raise_for_status()
    with open(path, "wb") as downloaded:
        for chunk in r.iter_content(chunk_size=chunk_size):
            downloaded.write(chunk)
def predict(image_name):
    """Run the classifier on one image and return a {label: confidence} dict."""
    labels = {0: "1-20", 1: "21-40", 2: "41 and above"}
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = preprocess(image_name).to(device)
    with torch.no_grad():
        outputs = model(inputs.unsqueeze(0))  # add a batch dimension
        # Convert raw logits to probabilities so the Label output shows confidences
        probs = torch.softmax(outputs, dim=1)
    values, indices = torch.topk(probs, k=3)
    return {labels[i.item()]: v.item() for i, v in zip(indices[0], values[0])}
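# predict() feeds Gradio's Label output, which renders the returned dict as a
# ranked list of classes, e.g. (hypothetical values):
#   {"1-20": 0.72, "21-40": 0.21, "41 and above": 0.07}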
def preprocess(image_name):
    """Resize, tensorize, and normalize a PIL image."""
    transforms = T.Compose([
        T.Resize((256, 256)),
        T.ToTensor(),
        # Standard ImageNet channel means and standard deviations
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    return transforms(image_name)
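# For an RGB input, preprocess() yields a (3, 256, 256) float tensor;
# predict() adds the batch dimension before the forward pass.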
def run_gradio():
    title = "AQC_NET PH - EEE199 Student Project"
    description = (
        "AQC_NET PH is an image-based deep learning model fine-tuned on a "
        "dataset collected in the National Capital Region of the Philippines "
        "using a Nova PM SDS011 sensor."
    )
    examples = ["test_img.jpg", "test_img2.jpg", "test_img3.jpg"]
    inputs = [gr.inputs.Image(type="pil", label="Input Image")]
    gr.Interface(
        predict,
        inputs,
        outputs="label",
        title=title,
        description=description,
        examples=examples,
        theme="huggingface",
    ).launch(debug=True, enable_queue=True)
# Build the model, fetch the trained weights if needed, and load them
model = AQC_NET(pretrain=True, num_label=3)

if not os.path.exists("weight.pth"):
    print("weight.pth does not exist. Downloading...")
    get_file("https://github.com/Kaldr4/EEE-199/releases/download/v1/weight.pth", "weight.pth")
    print("weight.pth downloaded")
else:
    print("Specified file (weight.pth) already downloaded. Skipping this step.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict = torch.load("weight.pth", map_location=device)
model.load_state_dict(state_dict)

run_gradio()
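# A minimal local-run sketch (assumes model.py and the test_img*.jpg examples
# from this repo sit next to app.py, and a Gradio release that still provides
# the legacy gr.inputs API):
#   pip install torch torchvision pillow requests "gradio<4"
#   python app.py   # downloads weight.pth on first run, then serves the demo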