# NOTE: page-capture residue, not part of the program: "Spaces: Sleeping Sleeping".
import os

import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image

from model import AQC_NET


def get_file(url, path, filename, chunk_size=128, timeout=30):
    """Download ``url`` and store the response body at ``path``.

    The body is streamed into a temporary ``<path>.part`` file and only moved
    to ``path`` once the whole download succeeded, so a failed or interrupted
    transfer never leaves a truncated file at ``path``.

    Args:
        url: Address to fetch with an HTTP GET request.
        path: Destination file path; replaced if it already exists.
        filename: Unused. Kept only so existing positional callers still work.
        chunk_size: Number of bytes requested per streamed chunk.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: For other request failures.
        OSError: If the file cannot be written or moved into place.
    """
    partial_path = f"{path}.part"
    try:
        # The context manager releases the connection even if writing fails.
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check an error page would be saved as the file.
            response.raise_for_status()
            with open(partial_path, "wb") as downloaded:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    downloaded.write(chunk)
        os.replace(partial_path, path)
    except BaseException:
        # Remove the incomplete download, then let the caller see the error.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
def predict(image_name):
    """Run the module-level ``model`` on one image and return its top-3 outputs.

    Args:
        image_name: The image to classify; it is handed unchanged to
            ``preprocess``.

    Returns:
        A dict mapping each class label to the raw value the model produced
        for that class.
        NOTE(review): whether these values are probabilities depends on
        AQC_NET's final layer, which is not visible here - confirm.
    """
    # NOTE(review): labels look like particulate-concentration ranges - confirm.
    class_names = {0: "1-20", 1: "21-40", 2: "41 and above"}

    model.eval()
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(target)

    # Add a leading batch dimension of size one before calling the model.
    batch = preprocess(image_name).to(target).unsqueeze(0)
    with torch.no_grad():
        scores = model(batch)

    top_values, top_indices = torch.topk(scores, k=3)
    print(top_values, top_indices)

    result = {}
    for class_index, score in zip(top_indices[0], top_values.detach()[0]):
        result[class_names[class_index.item()]] = score.item()
    return result
def preprocess(image_name):
    """Turn an input image into a resized, normalized tensor.

    Args:
        image_name: The image to transform. Despite the name, the Gradio input
            in this file is declared with ``type="pil"``, so this is an image
            object rather than a file name.

    Returns:
        The result of applying ``Resize((256, 256))``, ``ToTensor`` and
        ``Normalize`` to the image, in that order.
    """
    # Normalize below is given three per-channel means and stds, so grayscale,
    # palette or RGBA PIL images are converted to three-channel RGB first.
    if isinstance(image_name, Image.Image) and image_name.mode != "RGB":
        image_name = image_name.convert("RGB")

    pipeline = T.Compose([
        T.Resize((256, 256)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    return pipeline(image_name)
def run_gradio():
    """Build the Gradio interface around ``predict`` and launch it.

    The interface takes a single PIL image input, shows the result with the
    'label' output component, and offers three example image files.
    """
    image_input = gr.inputs.Image(type="pil", label="Input Image")
    interface_options = {
        "outputs": 'label',
        "title": "AQC_NET PH - EEE199 Student Project",
        "description": "AQC_NET PH is an image-based deep learning model finetuned on a data-set created in the National Capital Region of Philippines using a Nova PM SDS011 Sensor",
        "examples": ["test_img.jpg", "test_img2.jpg", "test_img3.jpg"],
        "theme": "huggingface",
    }
    demo = gr.Interface(predict, [image_input], **interface_options)
    demo.launch(debug=True, enable_queue=True)
# Script body: build the network, make sure the weights are on disk, load
# them, then start the web interface.
model = AQC_NET(pretrain=True, num_label=3)

# Download the weights only when no local copy exists yet.
if os.path.exists('weight.pth'):
    print('Specified file (weight.pth) already downloaded. Skipping this step.')
else:
    print("weight.pth does not exist. Downloading...")
    get_file(
        "https://github.com/Kaldr4/EEE-199/releases/download/v1/weight.pth",
        'weight.pth',
        "weight.pth",
    )
    print("weight.pth downloaded")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): torch.load unpickles the file, and these weights come from a
# remote URL - consider weights_only=True if the installed torch supports it.
state_dict = torch.load('weight.pth', map_location=device)
model.load_state_dict(state_dict)
run_gradio()